Skip to content

Commit 50e194e

Browse files
guoyejungrandao
authored and committed Jul 1, 2019
tools/python: add script to convert TensorFlow model (.pb) to native model (.model)
For example, given the TensorFlow model file espcn.pb, to generate the native model file espcn.model, just run: python convert.py espcn.pb In the current implementation, the native model file is generated for a specific dnn network with hard-coded Python scripts maintained outside of ffmpeg. For example, the srcnn network used by vf_sr is generated with https://github.com/HighVoltageRocknRoll/sr/blob/master/generate_header_and_model.py#L85 In this patch, the script is designed as a general solution which converts a general TensorFlow model .pb file into a .model file. The script currently contains some tricks to stay compatible with the current implementation; it will be refined step by step. The script is also added into the ffmpeg source tree. It is expected that there will be many more patches, and the community needs ownership of it. Another technical direction is to do the conversion in C/C++ code within the ffmpeg source tree. However, since the .pb file is organized with protocol buffers, it is not easy to do such work with a small amount of C/C++ code; see more discussion at http://ffmpeg.org/pipermail/ffmpeg-devel/2019-May/244496.html. So the Python script was chosen. Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
1 parent 4877b58 commit 50e194e

File tree

3 files changed

+254
-0
lines changed

3 files changed

+254
-0
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@
3636
/lcov/
3737
/src
3838
/mapfile
39+
/tools/python/__pycache__/

‎tools/python/convert.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2019 Guo Yejun
2+
#
3+
# This file is part of FFmpeg.
4+
#
5+
# FFmpeg is free software; you can redistribute it and/or
6+
# modify it under the terms of the GNU Lesser General Public
7+
# License as published by the Free Software Foundation; either
8+
# version 2.1 of the License, or (at your option) any later version.
9+
#
10+
# FFmpeg is distributed in the hope that it will be useful,
11+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13+
# Lesser General Public License for more details.
14+
#
15+
# You should have received a copy of the GNU Lesser General Public
16+
# License along with FFmpeg; if not, write to the Free Software
17+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18+
# ==============================================================================
19+
20+
# verified with Python 3.5.2 on Ubuntu 16.04
21+
import argparse
22+
import os
23+
from convert_from_tensorflow import *
24+
25+
def get_arguments():
    """Parse the converter's command line.

    Returns the argparse namespace with three attributes: ``outdir``
    (default ``'./'``), ``infmt`` (default ``'tensorflow'``) and the
    required positional ``infile``.
    """
    cli = argparse.ArgumentParser(
        description='generate native mode model with weights from deep learning model')

    # (flag, default, help) for every optional string argument.
    optional_args = (
        ('--outdir', './', 'where to put generated files'),
        ('--infmt', 'tensorflow', 'format of the deep learning model'),
    )
    for flag, fallback, description in optional_args:
        cli.add_argument(flag, type=str, default=fallback, help=description)

    cli.add_argument('infile', help='path to the deep learning model with weights')

    return cli.parse_args()
32+
33+
def main():
    """Convert the model named on the command line into a native .model file.

    The output file is placed in ``--outdir`` and named after the input
    file with its extension replaced by ``.model``.

    Raises:
        SystemExit: with status 1 when the input file does not exist or the
            requested input format is not supported.
    """
    args = get_arguments()

    # Validate everything before touching the file system.
    if not os.path.isfile(args.infile):
        print('the specified input file %s does not exist' % args.infile)
        # SystemExit instead of exit(): exit() is injected by the site module
        # and is missing when Python runs with -S or from a frozen build.
        raise SystemExit(1)

    if args.infmt != 'tensorflow':
        # Previously an unknown format fell through silently with status 0.
        print('the input format %s is not supported' % args.infmt)
        raise SystemExit(1)

    if not os.path.exists(args.outdir):
        print('create output directory %s' % args.outdir)
        # makedirs also creates missing parents, e.g. --outdir a/b/c.
        os.makedirs(args.outdir)

    basefile = os.path.splitext(os.path.basename(args.infile))[0]
    outfile = os.path.join(args.outdir, basefile) + '.model'

    convert_from_tensorflow(args.infile, outfile)


if __name__ == '__main__':
    main()
+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright (c) 2019 Guo Yejun
2+
#
3+
# This file is part of FFmpeg.
4+
#
5+
# FFmpeg is free software; you can redistribute it and/or
6+
# modify it under the terms of the GNU Lesser General Public
7+
# License as published by the Free Software Foundation; either
8+
# version 2.1 of the License, or (at your option) any later version.
9+
#
10+
# FFmpeg is distributed in the hope that it will be useful,
11+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13+
# Lesser General Public License for more details.
14+
#
15+
# You should have received a copy of the GNU Lesser General Public
16+
# License along with FFmpeg; if not, write to the Free Software
17+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18+
# ==============================================================================
19+
20+
import tensorflow as tf
21+
import numpy as np
22+
import sys, struct
23+
24+
# Restrict what `from convert_from_tensorflow import *` exposes to the single conversion entry point.
__all__ = ['convert_from_tensorflow']
25+
26+
# As a first step, this converter only aims to be compatible with vf_sr; it is not general yet.
27+
# It will be refined step by step.
28+
29+
class TFConverter:
    """Serialize the Conv2D / DepthToSpace layers of a TensorFlow graph.

    The file written to ``outfile`` is a flat binary stream in native byte
    order: one uint32 layer count followed by one record per layer.

    * Conv2D record: seven uint32 values (operation code, dilation, padding
      code, activation code, input channels, output channels, kernel size),
      then the float32 kernel in (out, height, width, in) order, then the
      float32 biases.
    * DepthToSpace record: two uint32 values (operation code, block size).

    A BiasAdd and an activation that directly follow a Conv2D are folded into
    that Conv2D's record. Nodes with any other operation are ignored.
    """

    def __init__(self, graph_def, nodes, outfile):
        """Store the graph, its node list and the output path.

        Args:
            graph_def: the parsed graph; only used by dump_for_tensorboard().
            nodes: mutable sequence of graph nodes; Identity nodes are
                removed from it in place by remove_identity().
            outfile: path of the native model file to create.
        """
        self.graph_def = graph_def
        self.nodes = nodes
        self.outfile = outfile
        self.layer_number = 0
        self.output_names = []
        self.name_node_dict = {}
        self.edges = {}
        # 'None' (no fused activation) takes code 3, the slot the table left
        # unused; it matches NONE in the native backend's activation enum
        # (RELU, TANH, SIGMOID, NONE, LEAKY_RELU). Without it, every Conv2D
        # lacking an activation raised KeyError.
        self.conv_activations = {'Relu': 0, 'Tanh': 1, 'Sigmoid': 2, 'None': 3, 'LeakyRelu': 4}
        self.conv_paddings = {'VALID': 2, 'SAME': 1}
        self.converted_nodes = set()
        self.op2code = {'Conv2D': 1, 'DepthToSpace': 2}

    def dump_for_tensorboard(self, logdir='/tmp/graph'):
        """Write the graph to ``logdir`` for inspection with tensorboard.

        Uses the TensorFlow 1.x graph/summary API. View the result with
        ``tensorboard --logdir=<logdir>``.
        """
        graph = tf.get_default_graph()
        tf.import_graph_def(self.graph_def, name="")
        writer = tf.summary.FileWriter(logdir, graph)
        # Close explicitly so the event file is flushed to disk.
        writer.close()

    def _first_consumer(self, node):
        """Return the first node that takes ``node`` as input, or None."""
        consumers = self.edges.get(node.name)
        return consumers[0] if consumers else None

    def get_conv2d_params(self, node):
        """Collect the kernel, bias and activation that belong to a Conv2D.

        A directly following BiasAdd and a directly following activation are
        marked as converted so they are not visited again.

        Returns:
            (kernel node, bias node or None, activation name); the activation
            name is 'None' when no supported activation follows.
        """
        knode = self.name_node_dict[node.input[1]]
        bnode = None
        activation = 'None'

        # The Conv2D (or its BiasAdd) may be the last node of the graph, in
        # which case it has no consumer at all.
        follower = self._first_consumer(node)
        if follower is not None and follower.op == 'BiasAdd':
            self.converted_nodes.add(follower.name)
            bnode = self.name_node_dict[follower.input[1]]
            follower = self._first_consumer(follower)
        if follower is not None and follower.op in self.conv_activations:
            self.converted_nodes.add(follower.name)
            activation = follower.op
        return knode, bnode, activation

    def dump_conv2d_to_file(self, node, f):
        """Append the record of one Conv2D layer to the binary stream ``f``.

        Raises:
            ValueError: if the kernel is not square; the record stores a
                single kernel size, so such a kernel cannot be represented.
        """
        assert(node.op == 'Conv2D')
        self.layer_number = self.layer_number + 1
        self.converted_nodes.add(node.name)
        knode, bnode, activation = self.get_conv2d_params(node)

        # NOTE(review): element 0 of `dilations` is read as the layer's
        # dilation; for NHWC data that is presumably the batch entry. Confirm
        # whether index 1 was intended.
        dilation = node.attr['dilations'].list.i[0]
        padding = self.conv_paddings[node.attr['padding'].s.decode("utf-8")]

        ktensor = knode.attr['value'].tensor
        kdims = ktensor.tensor_shape.dim
        filter_height = kdims[0].size
        filter_width = kdims[1].size
        in_channels = kdims[2].size
        out_channels = kdims[3].size
        if filter_height != filter_width:
            raise ValueError('Conv2D %s has a non-square %dx%d kernel, which is not supported'
                             % (node.name, filter_height, filter_width))

        kernel = np.frombuffer(ktensor.tensor_content, dtype=np.float32)
        kernel = kernel.reshape(filter_height, filter_width, in_channels, out_channels)
        # (height, width, in, out) -> (out, height, width, in)
        kernel = np.transpose(kernel, [3, 0, 1, 2])

        header = np.array([self.op2code[node.op], dilation, padding,
                           self.conv_activations[activation],
                           in_channels, out_channels, filter_height], dtype=np.uint32)
        # tobytes() emits C-order bytes exactly like tofile(), but also works
        # with in-memory streams.
        f.write(header.tobytes())
        f.write(kernel.tobytes())

        if bnode is None:
            # The record always carries one bias per output channel; a
            # convolution without BiasAdd is equivalent to all-zero biases.
            bias = np.zeros(out_channels, dtype=np.float32).tobytes()
        else:
            btensor = bnode.attr['value'].tensor
            if btensor.tensor_shape.dim[0].size == 1:
                # A single-element tensor keeps its value in float_val.
                bias = struct.pack("f", btensor.float_val[0])
            else:
                bias = btensor.tensor_content
        f.write(bias)

    def dump_depth2space_to_file(self, node, f):
        """Append the record of one DepthToSpace layer to the stream ``f``."""
        assert(node.op == 'DepthToSpace')
        self.layer_number = self.layer_number + 1
        block_size = node.attr['block_size'].i
        f.write(np.array([self.op2code[node.op], block_size], dtype=np.uint32).tobytes())
        self.converted_nodes.add(node.name)

    def generate_layer_number(self):
        """Compute ``self.layer_number`` with a dry run of the layer dump.

        The layer count is the first value of the native model file, but it
        is only known after all layers were visited, so the layers are first
        dumped into a throw-away in-memory buffer.
        """
        import io

        # Start from a clean state so repeated runs give the same count.
        self.layer_number = 0
        self.converted_nodes.clear()
        self.dump_layers_to_file(io.BytesIO())
        self.converted_nodes.clear()

    def dump_layers_to_file(self, f):
        """Write the records of all supported, not yet converted layers."""
        for node in self.nodes:
            if node.name in self.converted_nodes:
                continue
            if node.op == 'Conv2D':
                self.dump_conv2d_to_file(node, f)
            elif node.op == 'DepthToSpace':
                self.dump_depth2space_to_file(node, f)

    def dump_to_file(self):
        """Write the layer count and all layer records to ``self.outfile``."""
        self.generate_layer_number()
        layer_count = self.layer_number
        with open(self.outfile, 'wb') as f:
            f.write(np.array([layer_count], dtype=np.uint32).tobytes())
            self.dump_layers_to_file(f)
        # The second pass counted every layer again; keep the real count.
        self.layer_number = layer_count

    def generate_name_node_dict(self):
        """Index every node by its name in ``self.name_node_dict``."""
        for node in self.nodes:
            self.name_node_dict[node.name] = node

    def generate_output_names(self):
        """Record the names of nodes that are not the input of any node."""
        used_names = set()
        for node in self.nodes:
            used_names.update(node.input)

        for node in self.nodes:
            if node.name not in used_names:
                self.output_names.append(node.name)

    def remove_identity(self):
        """Drop Identity nodes and reconnect their consumers to their input.

        An Identity that is a graph output hands its name over to its input
        node, so the output name of the graph does not change.
        """
        id_nodes = []
        id_dict = {}
        for node in self.nodes:
            if node.op != 'Identity':
                continue
            name = node.name
            source = node.input[0]
            id_nodes.append(node)
            # do not change the output name
            if name in self.output_names:
                self.name_node_dict[source].name = name
                self.name_node_dict[name] = self.name_node_dict[source]
                del self.name_node_dict[source]
            else:
                id_dict[name] = source

        for idnode in id_nodes:
            self.nodes.remove(idnode)

        for node in self.nodes:
            # Iterate over a copy because entries are replaced on the fly.
            for i, name in enumerate(list(node.input)):
                if name not in id_dict:
                    continue
                # Follow chains of Identity nodes down to the real producer.
                target = id_dict[name]
                while target in id_dict:
                    target = id_dict[target]
                node.input[i] = target

    def generate_edges(self):
        """Map every input name to the list of nodes consuming it."""
        for node in self.nodes:
            for name in node.input:
                self.edges.setdefault(name, []).append(node)

    def run(self):
        """Analyse the graph and write the native model file."""
        self.generate_name_node_dict()
        self.generate_output_names()
        self.remove_identity()
        self.generate_edges()

        #check the graph with tensorboard with human eyes
        #self.dump_for_tensorboard()

        self.dump_to_file()
191+
192+
193+
def convert_from_tensorflow(infile, outfile):
    """Convert a serialized TensorFlow graph into a native model file.

    Args:
        infile: path of the binary GraphDef (.pb) file to read.
        outfile: path of the native model file to write.
    """
    # tf.GraphDef is only a top-level name in TensorFlow 1.x; TensorFlow 2.x
    # provides the same message class as tf.compat.v1.GraphDef.
    graph_def_class = getattr(tf, 'GraphDef', None)
    if graph_def_class is None:
        graph_def_class = tf.compat.v1.GraphDef

    # read the file in .proto format
    graph_def = graph_def_class()
    with open(infile, 'rb') as f:
        graph_def.ParseFromString(f.read())

    converter = TFConverter(graph_def, graph_def.node, outfile)
    converter.run()

0 commit comments

Comments
 (0)
Please sign in to comment.