Skip to content

Commit 4702de2

Browse files
authoredOct 27, 2017
Use FLAGS in main functions only + Updates to shuffling (tensorflow#2601)
1 parent edcd29f commit 4702de2

7 files changed

+111
-87
lines changed
 

‎official/mnist/convert_to_records.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ def _bytes_feature(value):
5050
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
5151

5252

53-
def convert_to(data_set, name):
53+
def convert_to(dataset, name, directory):
5454
"""Converts a dataset to TFRecords."""
55-
images = data_set.images
56-
labels = data_set.labels
57-
num_examples = data_set.num_examples
55+
images = dataset.images
56+
labels = dataset.labels
57+
num_examples = dataset.num_examples
5858

5959
if images.shape[0] != num_examples:
6060
raise ValueError('Images size %d does not match label size %d.' %
@@ -63,7 +63,7 @@ def convert_to(data_set, name):
6363
cols = images.shape[2]
6464
depth = images.shape[3]
6565

66-
filename = os.path.join(FLAGS.directory, name + '.tfrecords')
66+
filename = os.path.join(directory, name + '.tfrecords')
6767
print('Writing', filename)
6868
writer = tf.python_io.TFRecordWriter(filename)
6969
for index in range(num_examples):
@@ -80,15 +80,15 @@ def convert_to(data_set, name):
8080

8181
def main(unused_argv):
8282
# Get the data.
83-
data_sets = mnist.read_data_sets(FLAGS.directory,
83+
datasets = mnist.read_data_sets(FLAGS.directory,
8484
dtype=tf.uint8,
8585
reshape=False,
8686
validation_size=FLAGS.validation_size)
8787

8888
# Convert to Examples and write the result to TFRecords.
89-
convert_to(data_sets.train, 'train')
90-
convert_to(data_sets.validation, 'validation')
91-
convert_to(data_sets.test, 'test')
89+
convert_to(datasets.train, 'train', FLAGS.directory)
90+
convert_to(datasets.validation, 'validation', FLAGS.directory)
91+
convert_to(datasets.test, 'test', FLAGS.directory)
9292

9393

9494
if __name__ == '__main__':

‎official/mnist/mnist.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
}
5353

5454

55-
def input_fn(mode, batch_size=1):
55+
def input_fn(is_training, filename, batch_size=1, num_epochs=1):
5656
"""A simple input_fn using the contrib.data input pipeline."""
5757

5858
def example_parser(serialized_example):
@@ -71,21 +71,15 @@ def example_parser(serialized_example):
7171
label = tf.cast(features['label'], tf.int32)
7272
return image, tf.one_hot(label, 10)
7373

74-
if mode == tf.estimator.ModeKeys.TRAIN:
75-
tfrecords_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
76-
else:
77-
assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
78-
tfrecords_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
74+
dataset = tf.contrib.data.TFRecordDataset([filename])
7975

80-
assert tf.gfile.Exists(tfrecords_file), (
81-
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
82-
'file format.')
76+
if is_training:
77+
# When choosing shuffle buffer sizes, larger sizes result in better
78+
# randomness, while smaller sizes have better performance. Because MNIST is
79+
# a small dataset, we can easily shuffle the full epoch.
80+
dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
8381

84-
dataset = tf.contrib.data.TFRecordDataset([tfrecords_file])
85-
86-
# For training, repeat the dataset forever
87-
if mode == tf.estimator.ModeKeys.TRAIN:
88-
dataset = dataset.repeat()
82+
dataset = dataset.repeat(num_epochs)
8983

9084
# Map example_parser over dataset, and batch results by up to batch_size
9185
dataset = dataset.map(
@@ -96,13 +90,12 @@ def example_parser(serialized_example):
9690
return images, labels
9791

9892

99-
def mnist_model(inputs, mode):
93+
def mnist_model(inputs, mode, data_format):
10094
"""Takes the MNIST inputs and mode and outputs a tensor of logits."""
10195
# Input Layer
10296
# Reshape X to 4-D tensor: [batch_size, width, height, channels]
10397
# MNIST images are 28x28 pixels, and have one color channel
10498
inputs = tf.reshape(inputs, [-1, 28, 28, 1])
105-
data_format = FLAGS.data_format
10699

107100
if data_format is None:
108101
# When running on GPU, transpose the data from channels_last (NHWC) to
@@ -177,9 +170,9 @@ def mnist_model(inputs, mode):
177170
return logits
178171

179172

180-
def mnist_model_fn(features, labels, mode):
173+
def mnist_model_fn(features, labels, mode, params):
181174
"""Model function for MNIST."""
182-
logits = mnist_model(features, mode)
175+
logits = mnist_model(features, mode, params['data_format'])
183176

184177
predictions = {
185178
'classes': tf.argmax(input=logits, axis=1),
@@ -215,30 +208,36 @@ def mnist_model_fn(features, labels, mode):
215208

216209

217210
def main(unused_argv):
211+
# Make sure that training and testing data have been converted.
212+
train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
213+
test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
214+
assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
215+
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
216+
'file format.')
217+
218218
# Create the Estimator
219219
mnist_classifier = tf.estimator.Estimator(
220-
model_fn=mnist_model_fn, model_dir=FLAGS.model_dir)
220+
model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
221+
params={'data_format': FLAGS.data_format})
221222

222-
# Train the model
223+
# Set up training hook that logs the training accuracy every 100 steps.
223224
tensors_to_log = {
224225
'train_accuracy': 'train_accuracy'
225226
}
226-
227227
logging_hook = tf.train.LoggingTensorHook(
228228
tensors=tensors_to_log, every_n_iter=100)
229229

230-
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
231-
230+
# Train the model
232231
mnist_classifier.train(
233-
input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),
234-
steps=FLAGS.train_epochs * batches_per_epoch,
232+
input_fn=lambda: input_fn(
233+
True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
235234
hooks=[logging_hook])
236235

237236
# Evaluate the model and print results
238237
eval_results = mnist_classifier.evaluate(
239-
input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))
238+
input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
240239
print()
241-
print('Evaluation results:\n %s' % eval_results)
240+
print('Evaluation results:\n\t%s' % eval_results)
242241

243242

244243
if __name__ == '__main__':

‎official/mnist/mnist_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def input_fn(self):
3434
def mnist_model_fn_helper(self, mode):
3535
features, labels = self.input_fn()
3636
image_count = features.shape[0]
37-
spec = mnist.mnist_model_fn(features, labels, mode)
37+
spec = mnist.mnist_model_fn(
38+
features, labels, mode, {'data_format': 'channels_last'})
3839

3940
predictions = spec.predictions
4041
self.assertAllEqual(predictions['probabilities'].shape, (image_count, 10))
@@ -65,5 +66,4 @@ def test_mnist_model_fn_predict_mode(self):
6566

6667

6768
if __name__ == '__main__':
68-
mnist.FLAGS = mnist.parser.parse_args()
6969
tf.test.main()

‎official/resnet/cifar10_main.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,18 @@
7171
'validation': 10000,
7272
}
7373

74+
_SHUFFLE_BUFFER = 20000
75+
7476

7577
def record_dataset(filenames):
7678
"""Returns an input pipeline Dataset from `filenames`."""
7779
record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
7880
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
7981

8082

81-
def get_filenames(is_training):
83+
def get_filenames(is_training, data_dir):
8284
"""Returns a list of filenames."""
83-
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
85+
data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
8486

8587
assert os.path.exists(data_dir), (
8688
'Run cifar10_download_and_extract.py first to download and extract the '
@@ -135,7 +137,7 @@ def train_preprocess_fn(image, label):
135137
return image, label
136138

137139

138-
def input_fn(is_training, num_epochs=1):
140+
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
139141
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
140142
141143
Args:
@@ -145,42 +147,41 @@ def input_fn(is_training, num_epochs=1):
145147
Returns:
146148
A tuple of images and labels.
147149
"""
148-
dataset = record_dataset(get_filenames(is_training))
150+
dataset = record_dataset(get_filenames(is_training, data_dir))
149151
dataset = dataset.map(dataset_parser, num_threads=1,
150-
output_buffer_size=2 * FLAGS.batch_size)
152+
output_buffer_size=2 * batch_size)
151153

152154
# For training, preprocess the image and shuffle.
153155
if is_training:
154156
dataset = dataset.map(train_preprocess_fn, num_threads=1,
155-
output_buffer_size=2 * FLAGS.batch_size)
157+
output_buffer_size=2 * batch_size)
156158

157-
# Ensure that the capacity is sufficiently large to provide good random
158-
# shuffling.
159-
buffer_size = int(0.4 * _NUM_IMAGES['train'])
160-
dataset = dataset.shuffle(buffer_size=buffer_size)
159+
# When choosing shuffle buffer sizes, larger sizes result in better
160+
# randomness, while smaller sizes have better performance.
161+
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
161162

162163
# Subtract off the mean and divide by the variance of the pixels.
163164
dataset = dataset.map(
164165
lambda image, label: (tf.image.per_image_standardization(image), label),
165166
num_threads=1,
166-
output_buffer_size=2 * FLAGS.batch_size)
167+
output_buffer_size=2 * batch_size)
167168

168169
dataset = dataset.repeat(num_epochs)
169170

170171
# Batch results by up to batch_size, and then fetch the tuple from the
171172
# iterator.
172-
iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
173+
iterator = dataset.batch(batch_size).make_one_shot_iterator()
173174
images, labels = iterator.get_next()
174175

175176
return images, labels
176177

177178

178-
def cifar10_model_fn(features, labels, mode):
179+
def cifar10_model_fn(features, labels, mode, params):
179180
"""Model function for CIFAR-10."""
180181
tf.summary.image('images', features, max_outputs=6)
181182

182183
network = resnet_model.cifar10_resnet_v2_generator(
183-
FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
184+
params['resnet_size'], _NUM_CLASSES, params['data_format'])
184185

185186
inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
186187
logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
@@ -208,8 +209,8 @@ def cifar10_model_fn(features, labels, mode):
208209
if mode == tf.estimator.ModeKeys.TRAIN:
209210
# Scale the learning rate linearly with the batch size. When the batch size
210211
# is 128, the learning rate should be 0.1.
211-
initial_learning_rate = 0.1 * FLAGS.batch_size / 128
212-
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
212+
initial_learning_rate = 0.1 * params['batch_size'] / 128
213+
batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
213214
global_step = tf.train.get_or_create_global_step()
214215

215216
# Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
@@ -256,7 +257,12 @@ def main(unused_argv):
256257
# Set up a RunConfig to only save checkpoints once per training cycle.
257258
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
258259
cifar_classifier = tf.estimator.Estimator(
259-
model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config)
260+
model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config,
261+
params={
262+
'resnet_size': FLAGS.resnet_size,
263+
'data_format': FLAGS.data_format,
264+
'batch_size': FLAGS.batch_size,
265+
})
260266

261267
for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
262268
tensors_to_log = {
@@ -270,12 +276,12 @@ def main(unused_argv):
270276

271277
cifar_classifier.train(
272278
input_fn=lambda: input_fn(
273-
is_training=True, num_epochs=FLAGS.epochs_per_eval),
279+
True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
274280
hooks=[logging_hook])
275281

276282
# Evaluate the model and print results
277283
eval_results = cifar_classifier.evaluate(
278-
input_fn=lambda: input_fn(is_training=False))
284+
input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
279285
print(eval_results)
280286

281287

‎official/resnet/cifar10_test.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
tf.logging.set_verbosity(tf.logging.ERROR)
2828

29+
_BATCH_SIZE = 128
30+
2931

3032
class BaseTest(tf.test.TestCase):
3133

@@ -58,20 +60,25 @@ def test_dataset_input_fn(self):
5860
self.assertAllEqual(pixel, np.array([0, 1, 2]))
5961

6062
def input_fn(self):
61-
features = tf.random_uniform([FLAGS.batch_size, 32, 32, 3])
63+
features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
6264
labels = tf.random_uniform(
63-
[FLAGS.batch_size], maxval=9, dtype=tf.int32)
65+
[_BATCH_SIZE], maxval=9, dtype=tf.int32)
6466
return features, tf.one_hot(labels, 10)
6567

6668
def cifar10_model_fn_helper(self, mode):
6769
features, labels = self.input_fn()
68-
spec = cifar10_main.cifar10_model_fn(features, labels, mode)
70+
spec = cifar10_main.cifar10_model_fn(
71+
features, labels, mode, {
72+
'resnet_size': 32,
73+
'data_format': 'channels_last',
74+
'batch_size': _BATCH_SIZE,
75+
})
6976

7077
predictions = spec.predictions
7178
self.assertAllEqual(predictions['probabilities'].shape,
72-
(FLAGS.batch_size, 10))
79+
(_BATCH_SIZE, 10))
7380
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
74-
self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
81+
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
7582
self.assertEqual(predictions['classes'].dtype, tf.int64)
7683

7784
if mode != tf.estimator.ModeKeys.PREDICT:
@@ -97,6 +104,4 @@ def test_cifar10_model_fn_predict_mode(self):
97104

98105

99106
if __name__ == '__main__':
100-
cifar10_main.FLAGS = cifar10_main.parser.parse_args()
101-
FLAGS = cifar10_main.FLAGS
102107
tf.test.main()

‎official/resnet/imagenet_main.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,18 @@
7373
'validation': 50000,
7474
}
7575

76+
_SHUFFLE_BUFFER = 1500
7677

77-
def filenames(is_training):
78+
79+
def filenames(is_training, data_dir):
7880
"""Return filenames for dataset."""
7981
if is_training:
8082
return [
81-
os.path.join(FLAGS.data_dir, 'train-%05d-of-01024' % i)
83+
os.path.join(data_dir, 'train-%05d-of-01024' % i)
8284
for i in range(0, 1024)]
8385
else:
8486
return [
85-
os.path.join(FLAGS.data_dir, 'validation-%05d-of-00128' % i)
87+
os.path.join(data_dir, 'validation-%05d-of-00128' % i)
8688
for i in range(0, 128)]
8789

8890

@@ -129,9 +131,11 @@ def dataset_parser(value, is_training):
129131
return image, tf.one_hot(label, _LABEL_CLASSES)
130132

131133

132-
def input_fn(is_training, num_epochs=1):
134+
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
133135
"""Input function which provides batches for train or eval."""
134-
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames(is_training))
136+
dataset = tf.contrib.data.Dataset.from_tensor_slices(
137+
filenames(is_training, data_dir))
138+
135139
if is_training:
136140
dataset = dataset.shuffle(buffer_size=1024)
137141
dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
@@ -141,23 +145,24 @@ def input_fn(is_training, num_epochs=1):
141145

142146
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
143147
num_threads=5,
144-
output_buffer_size=FLAGS.batch_size)
148+
output_buffer_size=batch_size)
145149

146150
if is_training:
147-
buffer_size = 1250 + 2 * FLAGS.batch_size
148-
dataset = dataset.shuffle(buffer_size=buffer_size)
151+
# When choosing shuffle buffer sizes, larger sizes result in better
152+
# randomness, while smaller sizes have better performance.
153+
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
149154

150-
iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
155+
iterator = dataset.batch(batch_size).make_one_shot_iterator()
151156
images, labels = iterator.get_next()
152157
return images, labels
153158

154159

155-
def resnet_model_fn(features, labels, mode):
160+
def resnet_model_fn(features, labels, mode, params):
156161
"""Our model_fn for ResNet to be used with our Estimator."""
157162
tf.summary.image('images', features, max_outputs=6)
158163

159164
network = resnet_model.imagenet_resnet_v2(
160-
FLAGS.resnet_size, _LABEL_CLASSES, FLAGS.data_format)
165+
params['resnet_size'], _LABEL_CLASSES, params['data_format'])
161166
logits = network(
162167
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
163168

@@ -185,8 +190,8 @@ def resnet_model_fn(features, labels, mode):
185190
if mode == tf.estimator.ModeKeys.TRAIN:
186191
# Scale the learning rate linearly with the batch size. When the batch size is
187192
# 256, the learning rate should be 0.1.
188-
initial_learning_rate = 0.1 * FLAGS.batch_size / 256
189-
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
193+
initial_learning_rate = 0.1 * params['batch_size'] / 256
194+
batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
190195
global_step = tf.train.get_or_create_global_step()
191196

192197
# Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
@@ -235,7 +240,12 @@ def main(unused_argv):
235240
# Set up a RunConfig to only save checkpoints once per training cycle.
236241
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
237242
resnet_classifier = tf.estimator.Estimator(
238-
model_fn=resnet_model_fn, model_dir=FLAGS.model_dir, config=run_config)
243+
model_fn=resnet_model_fn, model_dir=FLAGS.model_dir, config=run_config,
244+
params={
245+
'resnet_size': FLAGS.resnet_size,
246+
'data_format': FLAGS.data_format,
247+
'batch_size': FLAGS.batch_size,
248+
})
239249

240250
for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
241251
tensors_to_log = {
@@ -250,12 +260,12 @@ def main(unused_argv):
250260
print('Starting a training cycle.')
251261
resnet_classifier.train(
252262
input_fn=lambda: input_fn(
253-
is_training=True, num_epochs=FLAGS.epochs_per_eval),
263+
True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
254264
hooks=[logging_hook])
255265

256266
print('Starting to evaluate.')
257267
eval_results = resnet_classifier.evaluate(
258-
input_fn=lambda: input_fn(is_training=False))
268+
input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
259269
print(eval_results)
260270

261271

‎official/resnet/imagenet_test.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
tf.logging.set_verbosity(tf.logging.ERROR)
2828

29+
_BATCH_SIZE = 32
2930
_LABEL_CLASSES = 1001
3031

3132

@@ -125,10 +126,10 @@ def test_tensor_shapes_resnet_200_with_gpu(self):
125126

126127
def input_fn(self):
127128
"""Provides random features and labels."""
128-
features = tf.random_uniform([FLAGS.batch_size, 224, 224, 3])
129+
features = tf.random_uniform([_BATCH_SIZE, 224, 224, 3])
129130
labels = tf.one_hot(
130131
tf.random_uniform(
131-
[FLAGS.batch_size], maxval=_LABEL_CLASSES - 1,
132+
[_BATCH_SIZE], maxval=_LABEL_CLASSES - 1,
132133
dtype=tf.int32),
133134
_LABEL_CLASSES)
134135

@@ -139,13 +140,18 @@ def resnet_model_fn_helper(self, mode):
139140
tf.train.create_global_step()
140141

141142
features, labels = self.input_fn()
142-
spec = imagenet_main.resnet_model_fn(features, labels, mode)
143+
spec = imagenet_main.resnet_model_fn(
144+
features, labels, mode, {
145+
'resnet_size': 50,
146+
'data_format': 'channels_last',
147+
'batch_size': _BATCH_SIZE,
148+
})
143149

144150
predictions = spec.predictions
145151
self.assertAllEqual(predictions['probabilities'].shape,
146-
(FLAGS.batch_size, _LABEL_CLASSES))
152+
(_BATCH_SIZE, _LABEL_CLASSES))
147153
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
148-
self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
154+
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
149155
self.assertEqual(predictions['classes'].dtype, tf.int64)
150156

151157
if mode != tf.estimator.ModeKeys.PREDICT:
@@ -171,6 +177,4 @@ def test_resnet_model_fn_predict_mode(self):
171177

172178

173179
if __name__ == '__main__':
174-
imagenet_main.FLAGS = imagenet_main.parser.parse_args()
175-
FLAGS = imagenet_main.FLAGS
176180
tf.test.main()

0 commit comments

Comments
 (0)
Please sign in to comment.