Skip to content

Commit 2491923

Browse files
author
Nikhil Thorat
authored
Add MNIST training demo with layers API. (tensorflow#8)
This is the same model as the mnist-core demo.
1 parent f8b51a1 commit 2491923

18 files changed

+4334
-11
lines changed

addition-rnn/index.html

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<!-- Copyright 2018 Google Inc. All Rights Reserved.
1+
<!-- Copyright 2018 Google LLC. All Rights Reserved.
22
Licensed under the Apache License, Version 2.0 (the "License");
33
you may not use this file except in compliance with the License.
44
You may obtain a copy of the License at

addition-rnn/index.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at

mnist-core/data.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -43,7 +43,7 @@ export class MnistData {
4343
this.shuffledTestIndex = 0;
4444
}
4545

46-
async fetch() {
46+
async load() {
4747
// Make a request for the MNIST sprited image.
4848
const img = new Image();
4949
const canvas = document.createElement('canvas');

mnist-core/index.html

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<!-- Copyright 2018 Google Inc. All Rights Reserved.
1+
<!-- Copyright 2018 Google LLC. All Rights Reserved.
22
Licensed under the Apache License, Version 2.0 (the "License");
33
you may not use this file except in compliance with the License.
44
You may obtain a copy of the License at

mnist-core/index.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -22,7 +22,7 @@ import * as ui from './ui';
2222
let data;
2323
async function load() {
2424
data = new MnistData();
25-
await data.fetch();
25+
await data.load();
2626
}
2727

2828
async function train() {

mnist-core/model.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at

mnist-core/ui.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at

mnist/README.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# TensorFlow.js Example: Training MNIST
2+
3+
This example shows you how to train MNIST (using the layers API).
4+
5+
Note: currently the entire dataset of MNIST images is stored in a PNG image we have
6+
sprited, and the code in `data.js` is responsible for converting it into
7+
`Tensor`s. This will become much simpler in the near future.

mnist/data.js

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
// Number of pixels per example; index.js reshapes each example to 28 x 28 x 1.
const IMAGE_SIZE = 784;
// Number of label bytes stored per example (presumably a one-hot vector over
// the ten digit classes -- confirm against the labels file format).
const NUM_CLASSES = 10;
// Total number of examples packed into the sprite image / labels file.
const NUM_DATASET_ELEMENTS = 65000;

// Fraction of the dataset used for training; the remainder is the test set.
const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * Usage: construct, `await load()`, then call `nextTrainBatch()` /
 * `nextTestBatch()`.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    // Cursors into the shuffled train / test index arrays.
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads the sprite image and the labels file, decodes the pixels into
   * floats in [0, 1], and splits everything into train and test sets.
   *
   * @returns {Promise<void>} Resolves once all dataset arrays are populated.
   * @throws {Error} If the sprite image cannot be loaded or decoded, or the
   *     labels request does not return a successful HTTP status.
   */
  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      // Without an error handler a failed download would leave load() pending
      // forever.
      img.onerror = () => {
        reject(new Error(
            `Failed to load MNIST sprite image: ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.onload = () => {
        // Exceptions thrown inside an event handler do not reject the
        // enclosing promise on their own, so forward them explicitly.
        try {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          // 4 bytes per float32 pixel value.
          const datasetBytesBuffer =
              new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

          // Decode the sprite in horizontal strips of `chunkSize` rows so the
          // canvas never has to hold the whole image at once.
          // NOTE(review): assumes the sprite is IMAGE_SIZE pixels wide, i.e.
          // one flattened example per row -- confirm against the asset.
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
                datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                IMAGE_SIZE * chunkSize);
            ctx.drawImage(
                img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                chunkSize);

            const imageData =
                ctx.getImageData(0, 0, canvas.width, canvas.height);

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is
              // grayscale, so just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        } catch (err) {
          reject(err);
        }
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    // 'arraybuffer' is not a valid fetch() `mode` (it is an XMLHttpRequest
    // responseType); the body is read as an ArrayBuffer below instead.
    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);
    if (!labelsResponse.ok) {
      throw new Error(
          `Failed to fetch MNIST labels (HTTP ${labelsResponse.status}): ` +
          MNIST_LABELS_PATH);
    }

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * Returns the next batch of shuffled training examples.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} The caller owns both
   *     tensors and must dispose them.
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          // Advance the cursor first, wrapping around at the end of the set.
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  /**
   * Returns the next batch of shuffled test examples.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} The caller owns both
   *     tensors and must dispose them.
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  /**
   * Assembles a batch by repeatedly asking `index` for an example index and
   * copying that example's pixels and label bytes into contiguous buffers.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @param {[Float32Array, Uint8Array]} data Flat image and label arrays.
   * @param {function(): number} index Returns the next example index.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} Tensors of shape
   *     [batchSize, IMAGE_SIZE] and [batchSize, NUM_CLASSES].
   */
  nextBatch(batchSize, data, index) {
    const [images, labelBytes] = data;
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() is a view, so the only copy is the one made by set().
      const image =
          images.subarray(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label = labelBytes.subarray(
          idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}

mnist/index.html

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
<!-- Copyright 2018 Google LLC. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License.
11+
==============================================================================-->
12+
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Train MNIST with the TensorFlow.js Layers API</title>
  <style>
    canvas {
      width: 100px;
    }
    .pred {
      font-size: 30px;
      width: 100px;
    }
    .pred-correct {
      background-color: #00cf00;
    }
    .pred-incorrect {
      background-color: red;
    }
    .pred-container {
      display: inline-block;
      width: 100px;
      margin: 10px;
    }
  </style>
</head>
<body>
  <h1>Train MNIST with the TensorFlow.js Layers API</h1>
  <p>Open your browser console</p>

  <div id="status">Loading data...</div>
  <div id="message"></div>
  <div id="images"></div>
  <script src="index.js"></script>
</body>
</html>

mnist/index.js

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
import {MnistData} from './data';
21+
import * as ui from './ui';
22+
23+
// Step size used by the SGD optimizer below.
const LEARNING_RATE = 0.1;
// Number of examples passed to each model.fit() call in train().
const BATCH_SIZE = 64;

// Convolutional classifier: two conv + max-pool stages followed by a dense
// softmax layer with one unit per digit class.
const model = tf.sequential({
  layers: [
    tf.layers.conv2d({
      inputShape: [28, 28, 1],
      kernelSize: 5,
      filters: 8,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'VarianceScaling'
    }),
    tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}),
    tf.layers.conv2d({
      kernelSize: 5,
      filters: 16,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'VarianceScaling'
    }),
    tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}),
    tf.layers.flatten(),
    tf.layers.dense({
      units: 10,
      // `useBias` is a boolean flag; the previous value `10` only worked
      // because it happens to be truthy.
      useBias: true,
      kernelInitializer: 'VarianceScaling',
      activation: 'softmax'
    })
  ]
});

// TODO(nsthorat): Use tf.train.sgd() once compile supports core optimizers.
const optimizer = new tf.optimizers.SGD({lr: LEARNING_RATE});
model.compile({optimizer, loss: 'categoricalCrossentropy'});
57+
58+
// Module-level handle to the dataset; assigned by load() and read by train()
// and test().
let data;

/**
 * Creates the MnistData instance and waits for it to finish loading.
 *
 * @returns {Promise<void>} Resolves once `data.load()` has completed.
 */
async function load() {
  data = new MnistData();
  await data.load();
}
63+
64+
/**
 * Trains the model on 100 shuffled batches of BATCH_SIZE examples, logging
 * the loss of each batch to the console.
 *
 * @returns {Promise<void>} Resolves when all batches have been fitted.
 */
async function train() {
  ui.isTraining();
  for (let i = 0; i < 100; i++) {
    const batch = data.nextTrainBatch(BATCH_SIZE);
    // reshape() returns a new tensor handle that must be disposed separately
    // from batch.xs, so keep a reference to it.
    const xs = batch.xs.reshape([BATCH_SIZE, 28, 28, 1]);
    try {
      // The entire dataset doesn't fit into memory so we call fit repeatedly
      // with batches.
      const history = await model.fit(
          {x: xs, y: batch.labels, batchSize: BATCH_SIZE, epochs: 1});
      console.log('loss:', history.history.loss[0]);
    } finally {
      // Release the batch tensors even if fit() rejects.
      xs.dispose();
      batch.xs.dispose();
      batch.labels.dispose();
    }
  }
}
82+
83+
/**
 * Runs the model on one batch of 50 test examples and hands the predicted
 * and true class indices to the UI.
 *
 * @returns {Promise<void>}
 */
async function test() {
  const testExamples = 50;
  const batch = data.nextTestBatch(testExamples);

  // Keep a handle on every intermediate tensor so it can be released below.
  const input = batch.xs.reshape([-1, 28, 28, 1]);
  const output = model.predict(input);

  const axis = 1;
  const labelIndices = batch.labels.argMax(axis);
  const predictedIndices = output.argMax(axis);

  const labels = Array.from(labelIndices.dataSync());
  const predictions = Array.from(predictedIndices.dataSync());

  // The values have been copied into plain arrays, so the intermediates are
  // no longer needed.
  input.dispose();
  output.dispose();
  labelIndices.dispose();
  predictedIndices.dispose();

  // NOTE(review): `batch` is intentionally not disposed here because
  // ui.showTestResults() receives it and may read it later -- confirm whether
  // the UI finishes with it synchronously.
  ui.showTestResults(batch, predictions, labels);
}
94+
95+
/**
 * Demo entry point: loads the data, trains the model, then shows predictions
 * for a batch of test examples.
 *
 * @returns {Promise<void>}
 */
async function mnist() {
  await load();
  await train();
  // Await so a failure in test() propagates to the caller instead of becoming
  // an unhandled rejection.
  await test();
}

// Report failures from any stage rather than leaving the promise floating.
mnist().catch((err) => {
  console.error(err);
});

0 commit comments

Comments
 (0)