-
Notifications
You must be signed in to change notification settings - Fork 7
/
checkpoint_averaging.py
125 lines (93 loc) · 4.15 KB
/
checkpoint_averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import operator
import os
import numpy as np
import tensorflow as tf
def parseargs():
    """Parse the command-line options for checkpoint averaging.

    Returns:
        argparse.Namespace with attributes ``path`` (checkpoint directory),
        ``checkpoints`` (how many of the newest checkpoints to average),
        ``output`` (directory for the averaged checkpoint) and ``gpu``
        (device index, default 0).

    Exits via ``parser.error`` (SystemExit) when a required option is missing
    or ``--checkpoints`` is not a positive integer.
    """
    msg = "Average checkpoints"
    usage = "average.py [<args>] [-h | --help]"
    parser = argparse.ArgumentParser(description=msg, usage=usage)
    parser.add_argument("--path", type=str, required=True,
                        help="checkpoint dir")
    parser.add_argument("--checkpoints", type=int, required=True,
                        help="number of checkpoints to use")
    # main() uses --output unconditionally (os.path.join / rstrip), so a
    # missing value used to crash only after all checkpoints were read.
    parser.add_argument("--output", type=str, required=True,
                        help="output path")
    parser.add_argument("--gpu", type=int, default=0,
                        help="the default gpu device index")
    args = parser.parse_args()
    # Zero selects nothing; a negative count would slice off the oldest
    # checkpoints instead of taking the newest N.
    if args.checkpoints < 1:
        parser.error("--checkpoints must be a positive integer")
    return args
def get_checkpoints(path):
    """List the checkpoints recorded in ``path/checkpoint``, newest first.

    Reads the TensorFlow ``checkpoint`` state file in ``path``, collects every
    ``all_model_checkpoint_paths`` entry and orders them by the integer after
    the last ``-`` in the name, highest first.

    Args:
        path: Directory containing the ``checkpoint`` state file.

    Returns:
        List of checkpoint path prefixes, each joined onto ``path`` (an
        absolute entry is kept as-is by ``os.path.join``).

    Raises:
        ValueError: If ``path/checkpoint`` does not exist, or an entry has no
            integer suffix after its last ``-``.
    """
    state_file = os.path.join(path, "checkpoint")
    if not tf.gfile.Exists(state_file):
        raise ValueError("Cannot find checkpoints in %s" % path)

    checkpoint_names = []
    with tf.gfile.GFile(state_file) as fd:
        for line in fd:
            # Split on the FIRST colon only so values that contain ':'
            # themselves (e.g. "gs://bucket/model.ckpt-100") stay intact.
            key, sep, value = line.partition(":")
            # The "model_checkpoint_path" line duplicates the newest
            # "all_model_checkpoint_paths" entry; blank lines carry nothing.
            if not sep or key.strip() != "all_model_checkpoint_paths":
                continue
            name = value.strip()[1:-1]  # drop the surrounding quotes
            step = int(name.split("-")[-1])
            checkpoint_names.append((step, os.path.join(path, name)))

    checkpoint_names.sort(key=operator.itemgetter(0), reverse=True)
    return [name for _, name in checkpoint_names]
def checkpoint_exists(path):
    """Tell whether any file belonging to the checkpoint prefix exists.

    Probes ``path``, then ``path + ".meta"``, then ``path + ".index"`` and
    stops at the first one found.

    Args:
        path: Checkpoint path prefix.

    Returns:
        True if at least one probed file exists, otherwise False.
    """
    probe_suffixes = ("", ".meta", ".index")
    return any(tf.gfile.Exists(path + suffix) for suffix in probe_suffixes)
def _accumulate_average(checkpoints):
    """Return the element-wise mean of each non-global_step variable.

    The variable list is taken from ``checkpoints[0]``. Every such variable
    is then read from each checkpoint, summed into a NumPy ``np.zeros``
    (float64) accumulator, and divided by ``len(checkpoints)``.

    Returns:
        (values, dtypes): dicts keyed by variable name holding the averaged
        array and the dtype of the tensor as read from the checkpoints.

    Raises:
        ValueError: If a checkpoint stores a variable with a different shape
            than the first checkpoint (would otherwise broadcast silently).
    """
    values, dtypes = {}, {}
    for name, shape in tf.contrib.framework.list_variables(checkpoints[0]):
        if not name.startswith("global_step"):
            values[name] = np.zeros(shape)

    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in values:
            tensor = reader.get_tensor(name)
            if np.shape(tensor) != values[name].shape:
                raise ValueError(
                    "Shape mismatch for %s in %s: %s vs %s"
                    % (name, checkpoint, np.shape(tensor), values[name].shape))
            dtypes[name] = tensor.dtype
            values[name] += tensor
        tf.logging.info("Read from checkpoint %s", checkpoint)

    for name in values:
        values[name] /= len(checkpoints)
    return values, dtypes


def _copy_params(src_dir, dst_dir):
    """Copy every ``*.json`` file directly inside ``src_dir`` into ``dst_dir``."""
    for src in tf.gfile.Glob(os.path.join(src_dir, "*.json")):
        # Build the target from the basename rather than string-replacing the
        # directory, which could also rewrite matching text later in the path.
        dst = os.path.join(dst_dir, os.path.basename(src))
        # Skip self-copies (when the output directory is the input directory).
        if os.path.normpath(src) == os.path.normpath(dst):
            continue
        tf.gfile.Copy(src, dst, overwrite=True)


def main(_):
    """Average the newest ``FLAGS.checkpoints`` checkpoints in ``FLAGS.path``.

    The averaged variables (plus a fresh ``global_step`` of 0) are saved in
    ``FLAGS.output`` under the ``average`` prefix, and the ``*.json`` files
    from ``FLAGS.path`` are copied next to them. Reads the module-level
    ``FLAGS`` produced by ``parseargs()``.

    Raises:
        ValueError: If the checkpoint count is not positive, no checkpoints
            are listed, or none of the selected ones exist on disk.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.checkpoints < 1:
        raise ValueError(
            "--checkpoints must be positive, got %d" % FLAGS.checkpoints)

    candidates = get_checkpoints(FLAGS.path)[:FLAGS.checkpoints]
    if not candidates:
        raise ValueError("No checkpoints provided for averaging.")
    checkpoints = [c for c in candidates if checkpoint_exists(c)]
    if not checkpoints:
        # Report the paths that were probed, not just how many were requested.
        raise ValueError(
            "None of the provided checkpoints exist. %s" % candidates)

    var_values, var_dtypes = _accumulate_average(checkpoints)

    # Build one variable + placeholder + assign op per averaged tensor; all
    # three lists follow var_values' iteration order.
    tf_vars = [
        tf.get_variable(name, shape=value.shape, dtype=var_dtypes[name])
        for name, value in var_values.items()
    ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for v, p in zip(tf_vars, placeholders)]
    global_step = tf.Variable(0, name="global_step", trainable=False,
                              dtype=tf.int64)
    saver = tf.train.Saver(tf.global_variables())

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.visible_device_list = "%s" % FLAGS.gpu

    # Saver.save refuses to write into a directory that does not exist.
    if not tf.gfile.IsDirectory(FLAGS.output):
        tf.gfile.MakeDirs(FLAGS.output)

    saved_name = os.path.join(FLAGS.output, "average")
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.global_variables_initializer())
        for p, assign_op, value in zip(placeholders, assign_ops,
                                       var_values.values()):
            sess.run(assign_op, {p: value})
        saver.save(sess, saved_name, global_step=global_step)
    tf.logging.info("Averaged checkpoints saved in %s", saved_name)

    _copy_params(FLAGS.path, FLAGS.output)
if __name__ == "__main__":
    # main() reads the module-level FLAGS, so the command line must be
    # parsed before handing control to TensorFlow's app runner.
    FLAGS = parseargs()
    tf.app.run(main=main)