-
Notifications
You must be signed in to change notification settings - Fork 0
/
freeze.py
75 lines (58 loc) · 2.29 KB
/
freeze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys
import argparse
import subprocess
from pathlib import Path
import tensorflow as tf
def analyze_inputs_outputs(graph):
    """Guess the input and output operations of a graph.

    Inputs are operations with no data inputs (``op.inputs`` is empty) whose
    type is not ``'Const'``. Outputs are operations that no other operation in
    the graph consumes through ``op.inputs``. Only ``op.inputs`` is
    inspected.

    Args:
        graph: Object exposing ``get_operations()`` (e.g. a ``tf.Graph``).

    Returns:
        Tuple ``(inputs, outputs)``. Both are lists of operations in the order
        returned by ``graph.get_operations()``, so the result is deterministic.
    """
    ops = graph.get_operations()

    # Every operation whose output tensor feeds some other operation.
    consumed = {tensor.op for op in ops for tensor in op.inputs}

    # Source ops (no data inputs) that are not constants act as graph inputs.
    inputs = [op for op in ops if len(op.inputs) == 0 and op.type != 'Const']

    # Ops nobody consumes are the graph's sinks. Filtering the ops list, rather
    # than returning list(set), keeps graph order and makes the result
    # reproducible.
    outputs = [op for op in ops if op not in consumed]
    return inputs, outputs
def freeze_graph(directory_path, output_node_names=None,
                 meta_filename='my_model.ckpt.meta',
                 frozen_filename='frozen_graph.pb'):
    """Freeze the latest checkpoint in *directory_path* into one GraphDef file.

    The function:

    - imports the meta graph *meta_filename*;
    - restores the checkpoint reported by
      ``tf.compat.v1.train.latest_checkpoint``;
    - converts variables to constants with
      ``tf.compat.v1.graph_util.convert_variables_to_constants``, keeping
      *output_node_names*;
    - writes the serialized result to *frozen_filename*.

    All files live in *directory_path*.

    Args:
        directory_path: Directory (str or path-like) holding the meta graph and
            the checkpoint files.
        output_node_names: Names of the output nodes to keep. Defaults to
            ``['root/Sigmoid']``.
        meta_filename: Meta-graph file name inside *directory_path*.
        frozen_filename: Output file name inside *directory_path*.

    Returns:
        ``pathlib.Path`` of the written frozen graph.

    Raises:
        ValueError: If no checkpoint is found in *directory_path*.
    """
    if output_node_names is None:
        output_node_names = ['root/Sigmoid']  # Output nodes
    root_path = Path(directory_path)
    meta_path = root_path / meta_filename
    frozen_path = root_path / frozen_filename

    # The TF1 graph/session APIs below need v2 behavior switched off.
    tf.compat.v1.disable_v2_behavior()

    device_name = "/cpu:0"
    # Use a private graph rather than the process-wide default graph, so that
    # repeated calls do not accumulate duplicate imported nodes.
    graph = tf.Graph()
    with graph.as_default(), graph.device(device_name):
        with tf.compat.v1.Session(graph=graph) as sess:
            # Restore the graph structure.
            saver = tf.compat.v1.train.import_meta_graph(str(meta_path), clear_devices=True)

            # Locate and load the weights.
            latest_checkpoint_path = tf.compat.v1.train.latest_checkpoint(str(root_path))
            if latest_checkpoint_path is None:
                raise ValueError(f'No checkpoint found in {root_path}')
            print(f'Tensorflow reading {latest_checkpoint_path} before freezing')
            # import_meta_graph returns None when the graph has no variables;
            # in that case there is nothing to restore.
            if saver is not None:
                saver.restore(sess, latest_checkpoint_path)

            # Fold variables into constants reachable from the output nodes.
            frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
                sess,
                sess.graph_def,
                output_node_names)

    # Save the frozen graph.
    frozen_path.write_bytes(frozen_graph_def.SerializeToString())
    return frozen_path
def run_console_tool(tool_arguments, python_executable=None):
    """Re-run this script in a subprocess and capture its output.

    Args:
        tool_arguments: Command-line arguments (strings) appended after the
            script path.
        python_executable: Interpreter to launch. Defaults to
            ``venv/bin/python`` under the current working directory.

    Returns:
        ``subprocess.CompletedProcess`` whose ``stdout`` and ``stderr`` hold the
        captured bytes.
    """
    if python_executable is None:
        python_executable = Path.cwd() / 'venv' / 'bin' / 'python'
    options = [str(python_executable), __file__, *tool_arguments]
    print('[SUBPROCESS] {}'.format(' '.join(options)))
    # Piping both streams is exactly what capture_output=True does on 3.7+,
    # and it also works on 3.6, which predates that keyword. This replaces a
    # version check on sys.version_info.major (always 3) that never captured.
    return subprocess.run(options, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if __name__ == '__main__':
    # Command-line entry point: freeze the checkpoint found in --checkpoint_dir.
    parser = argparse.ArgumentParser(
        description='Freeze the latest TensorFlow checkpoint in a directory into a frozen graph.')
    # type=Path also converts the string default, so args.checkpoint_dir is always a Path.
    parser.add_argument("--checkpoint_dir", default='./trained_models/binary_sum_v1/', type=Path,
                        help="Directory containing the model's .meta file and checkpoint files")
    args = parser.parse_args()
    freeze_graph(args.checkpoint_dir)