     'validation': 10000,
 }

+_SHUFFLE_BUFFER = 20000
+

 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
   return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)


-def get_filenames(is_training):
+def get_filenames(is_training, data_dir):
   """Returns a list of filenames."""
-  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+  data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

   assert os.path.exists(data_dir), (
       'Run cifar10_download_and_extract.py first to download and extract the '
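For context (not part of this patch): each record in the CIFAR-10 binary format is one label byte followed by 32 * 32 * 3 = 3072 pixel bytes, which is where `record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1 = 3073` comes from. A minimal sketch of a parser for such records, assuming the module's `_HEIGHT = _WIDTH = 32` and `_DEPTH = 3` (the file's actual `dataset_parser` is not shown in this hunk):

```python
def parse_record_sketch(raw_record):
  # Decode the fixed-length record string into 3073 uint8 values.
  record = tf.decode_raw(raw_record, tf.uint8)
  label = tf.cast(record[0], tf.int32)  # first byte is the class label
  # The remaining 3072 bytes are the image, stored depth-major (CHW).
  image = tf.reshape(record[1:], [_DEPTH, _HEIGHT, _WIDTH])
  # Convert to HWC float32, the layout the rest of the pipeline expects.
  return tf.cast(tf.transpose(image, [1, 2, 0]), tf.float32), label
```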
@@ -135,7 +137,7 @@ def train_preprocess_fn(image, label):
   return image, label


-def input_fn(is_training, num_epochs=1):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.

   Args:
@@ -145,42 +147,41 @@ def input_fn(is_training, num_epochs=1):
   Returns:
     A tuple of images and labels.
   """
-  dataset = record_dataset(get_filenames(is_training))
+  dataset = record_dataset(get_filenames(is_training, data_dir))
   dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * FLAGS.batch_size)
+                        output_buffer_size=2 * batch_size)

   # For training, preprocess the image and shuffle.
   if is_training:
     dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * FLAGS.batch_size)
+                          output_buffer_size=2 * batch_size)

-    # Ensure that the capacity is sufficiently large to provide good random
-    # shuffling.
-    buffer_size = int(0.4 * _NUM_IMAGES['train'])
-    dataset = dataset.shuffle(buffer_size=buffer_size)
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance.
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

   # Subtract off the mean and divide by the variance of the pixels.
   dataset = dataset.map(
       lambda image, label: (tf.image.per_image_standardization(image), label),
       num_threads=1,
-      output_buffer_size=2 * FLAGS.batch_size)
+      output_buffer_size=2 * batch_size)

   dataset = dataset.repeat(num_epochs)

   # Batch results by up to batch_size, and then fetch the tuple from the
   # iterator.
-  iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
+  iterator = dataset.batch(batch_size).make_one_shot_iterator()
   images, labels = iterator.get_next()

   return images, labels


-def cifar10_model_fn(features, labels, mode):
+def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
   tf.summary.image('images', features, max_outputs=6)

   network = resnet_model.cifar10_resnet_v2_generator(
-      FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
+      params['resnet_size'], _NUM_CLASSES, params['data_format'])

   inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
   logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
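As a sanity check, a hypothetical standalone call to the refactored `input_fn` (the data path is a placeholder; this snippet is not part of the patch):

```python
images, labels = input_fn(
    is_training=True, data_dir='/tmp/cifar10_data',  # placeholder path
    batch_size=128, num_epochs=1)
with tf.Session() as sess:
  image_batch, label_batch = sess.run([images, labels])
  print(image_batch.shape)  # (128, 32, 32, 3)
```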
@@ -208,8 +209,8 @@ def cifar10_model_fn(features, labels, mode):
   if mode == tf.estimator.ModeKeys.TRAIN:
     # Scale the learning rate linearly with the batch size. When the batch size
     # is 128, the learning rate should be 0.1.
-    initial_learning_rate = 0.1 * FLAGS.batch_size / 128
-    batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
+    initial_learning_rate = 0.1 * params['batch_size'] / 128
+    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
     global_step = tf.train.get_or_create_global_step()

     # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
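A quick worked example of the linear scaling rule in the comment above: with `params['batch_size']` of 256 the initial rate becomes 0.1 * 256 / 128 = 0.2, and with 64 it becomes 0.05; at the reference batch size of 128, `batches_per_epoch` is 50000 / 128 = 390.625.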
@@ -256,7 +257,12 @@ def main(unused_argv):
   # Set up a RunConfig to only save checkpoints once per training cycle.
   run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
   cifar_classifier = tf.estimator.Estimator(
-      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config)
+      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config,
+      params={
+          'resnet_size': FLAGS.resnet_size,
+          'data_format': FLAGS.data_format,
+          'batch_size': FLAGS.batch_size,
+      })

   for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
     tensors_to_log = {
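For readers unfamiliar with this corner of the `tf.estimator` API: the `params` dict passed to the `Estimator` constructor is forwarded unchanged to the `model_fn`'s `params` argument, which is what lets `cifar10_model_fn` drop its `FLAGS` lookups. A toy sketch with hypothetical names (not from this patch):

```python
def toy_model_fn(features, labels, mode, params):
  # `params` here is the same dict given to the Estimator constructor below.
  loss = tf.reduce_mean(tf.square(features)) * params['loss_scale']
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

toy = tf.estimator.Estimator(model_fn=toy_model_fn,
                             params={'loss_scale': 2.0})
```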
@@ -270,12 +276,12 @@ def main(unused_argv):

     cifar_classifier.train(
         input_fn=lambda: input_fn(
-            is_training=True, num_epochs=FLAGS.epochs_per_eval),
+            True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
         hooks=[logging_hook])

     # Evaluate the model and print results
     eval_results = cifar_classifier.evaluate(
-        input_fn=lambda: input_fn(is_training=False))
+        input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
     print(eval_results)

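One stylistic note on the closures above: the arguments are now passed positionally, so the call sites are sensitive to `input_fn`'s argument order. An equivalent formulation with `functools.partial` (an alternative, not what this patch does) keeps the keywords explicit:

```python
from functools import partial

train_input_fn = partial(input_fn, is_training=True, data_dir=FLAGS.data_dir,
                         batch_size=FLAGS.batch_size,
                         num_epochs=FLAGS.epochs_per_eval)
cifar_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
```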