"""Functions for downloading and reading MNIST data."""
from __future__ import print_function
import gzip
import os
import urllib
import numpy

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath


def _read32(bytestream):
  # MNIST IDX files store magic numbers and sizes as 32-bit big-endian ints;
  # index into the 1-element array so callers get a plain scalar back.
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data
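
# Note: for the standard MNIST files, extract_images() returns uint8 arrays of
# shape (60000, 28, 28, 1) for the training set and (10000, 28, 28, 1) for the
# test set.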


def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot
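
# Illustrative example (not part of the original file):
#   dense_to_one_hot(numpy.array([1, 3]), num_classes=4)
# returns the float array
#   [[0., 1., 0., 0.],
#    [0., 0., 0., 1.]]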


def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels


class DataSet(object):

  def __init__(self, images, labels, fake_data=False):
    if fake_data:
      self._num_examples = 10000
    else:
      assert images.shape[0] == labels.shape[0], (
          "images.shape: %s labels.shape: %s" % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]
      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1).
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      # Convert from [0, 255] -> [0.0, 1.0].
      images = images.astype(numpy.float32)
      images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1.0 for _ in xrange(784)]
      fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch.
      self._epochs_completed += 1
      # Shuffle the data.
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch.
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]
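
# Typical use of DataSet.next_batch (an illustrative sketch, not part of the
# original file): a training loop pulls successive mini-batches; when an epoch
# is exhausted the DataSet reshuffles itself and bumps epochs_completed.
#
#   for _ in xrange(1000):
#     batch_images, batch_labels = some_data_set.next_batch(50)
#     # ... feed batch_images / batch_labels to a model ...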


def read_data_sets(train_dir, fake_data=False, one_hot=False):
  """Download MNIST if needed and return train/validation/test DataSets."""
  class DataSets(object):
    pass
  data_sets = DataSets()

  if fake_data:
    data_sets.train = DataSet([], [], fake_data=True)
    data_sets.validation = DataSet([], [], fake_data=True)
    data_sets.test = DataSet([], [], fake_data=True)
    return data_sets

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  # Carve a validation set out of the front of the training data.
  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  data_sets.train = DataSet(train_images, train_labels)
  data_sets.validation = DataSet(validation_images, validation_labels)
  data_sets.test = DataSet(test_images, test_labels)
  return data_sets
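

# A minimal smoke-test sketch, not part of the original module. The
# 'MNIST_data' directory name and the batch size of 100 are arbitrary choices.
if __name__ == '__main__':
  mnist = read_data_sets('MNIST_data', one_hot=True)
  # With VALIDATION_SIZE = 5000 this should report 55000 / 5000 / 10000.
  print(mnist.train.num_examples, mnist.validation.num_examples,
        mnist.test.num_examples)
  batch_images, batch_labels = mnist.train.next_batch(100)
  print(batch_images.shape, batch_labels.shape)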