Skip to content

Commit a698a4e

Browse files
authored
[jena-weather] Fix baseline; add unit tests; upgrade to tfjs 1.1.0 (tensorflow#271)
- Fix a crash in the `--modelType baseline` command - Add unit tests - Upgrade to tfjs and tfjs-node/tfjs-node-gpu 1.1.0 - Add early stopping with a default patience of 2
1 parent e6d0b76 commit a698a4e

File tree

7 files changed

+258
-67
lines changed

7 files changed

+258
-67
lines changed

jena-weather/data_test.js

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {JenaWeatherData} from "./data";
19+
20+
global.fetch = require('node-fetch');
21+
22+
describe('JenaWeatherData', () => {
23+
it('construct, load and basic public methods', async () => {
24+
const dataset = new JenaWeatherData();
25+
await dataset.load();
26+
27+
expect(dataset.getDataColumnNames().length).toEqual(14);
28+
expect(dataset.getDataColumnNames()[0]).toEqual('p (mbar)');
29+
expect(dataset.getDataColumnNames()[1]).toEqual('T (degC)');
30+
31+
expect(new Date(dataset.getTime(0)).getTime()).toBeGreaterThan(0);
32+
33+
const columnData = dataset.getColumnData('T (degC)', false, true, 0, 30, 3);
34+
expect(columnData.length).toEqual(10);
35+
36+
const func = dataset.getNextBatchFunction(
37+
true, 1000, 100, 32, 10, 0, 10000, true, false);
38+
for (let i = 0; i < 2; ++i) {
39+
const item = func.next();
40+
expect(item.done).toEqual(false);
41+
expect(item.value.xs.shape).toEqual([32, 100, 14]);
42+
}
43+
});
44+
});

jena-weather/models.js

+21-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
*/
1717

1818
/**
19-
* Creating and training `tf.Model`s for the temperature prediction problem.
19+
* Creating and training `tf.LayersModel`s for the temperature prediction
20+
* problem.
2021
*
2122
* This file is used to create models for both
2223
* - the browser: see [index.js](./index.js), and
@@ -51,23 +52,24 @@ const VAL_MAX_ROW = 300000;
5152
export async function getBaselineMeanAbsoluteError(
5253
jenaWeatherData, normalize, includeDateTime, lookBack, step, delay) {
5354
const batchSize = 128;
54-
const nextBatchFn = jenaWeatherData.getNextBatchFunction(
55-
false, lookBack, delay, batchSize, step, VAL_MIN_ROW, VAL_MAX_ROW,
56-
normalize, includeDateTime);
57-
const dataset = tf.data.generator(nextBatchFn);
55+
const dataset = tf.data.generator(
56+
() => jenaWeatherData.getNextBatchFunction(
57+
false, lookBack, delay, batchSize, step, VAL_MIN_ROW,
58+
VAL_MAX_ROW, normalize, includeDateTime));
5859

5960
const batchMeanAbsoluteErrors = [];
6061
const batchSizes = [];
6162
await dataset.forEach(dataItem => {
62-
const features = dataItem[0];
63-
const targets = dataItem[1];
63+
const features = dataItem.xs;
64+
const targets = dataItem.ys;
6465
const timeSteps = features.shape[1];
6566
batchSizes.push(features.shape[0]);
6667
batchMeanAbsoluteErrors.push(tf.tidy(
6768
() => tf.losses.absoluteDifference(
6869
targets,
6970
features.gather([timeSteps - 1], 1).gather([1], 2).squeeze([2]))));
7071
});
72+
7173
const meanAbsoluteError = tf.tidy(() => {
7274
const batchSizesTensor = tf.tensor1d(batchSizes);
7375
const batchMeanAbsoluteErrorsTensor = tf.stack(batchMeanAbsoluteErrors);
@@ -83,7 +85,7 @@ export async function getBaselineMeanAbsoluteError(
8385
* Build a linear-regression model for the temperature-prediction problem.
8486
*
8587
* @param {tf.Shape} inputShape Input shape (without the batch dimenson).
86-
* @returns {tf.Model} A TensorFlow.js tf.Model instance.
88+
* @returns {tf.LayersModel} A TensorFlow.js tf.LayersModel instance.
8789
*/
8890
function buildLinearRegressionModel(inputShape) {
8991
const model = tf.sequential();
@@ -102,9 +104,9 @@ function buildLinearRegressionModel(inputShape) {
102104
* @param {number} dropoutRate Dropout rate of an optional dropout layer
103105
* inserted between the two dense layers of the MLP. Optional. If not
104106
* specified, no dropout layers will be included in the MLP.
105-
* @returns {tf.Model} A TensorFlow.js tf.Model instance.
107+
* @returns {tf.LayersModel} A TensorFlow.js tf.LayersModel instance.
106108
*/
107-
function buildMLPModel(inputShape, kernelRegularizer, dropoutRate) {
109+
export function buildMLPModel(inputShape, kernelRegularizer, dropoutRate) {
108110
const model = tf.sequential();
109111
model.add(tf.layers.flatten({inputShape}));
110112
model.add(
@@ -120,9 +122,10 @@ function buildMLPModel(inputShape, kernelRegularizer, dropoutRate) {
120122
* Build a simpleRNN-based model for the temperature-prediction problem.
121123
*
122124
* @param {tf.Shape} inputShape Input shape (without the batch dimenson).
123-
* @returns {tf.Model} A TensorFlow.js model consisting of a simpleRNN layer.
125+
* @returns {tf.LayersModel} A TensorFlow.js model consisting of a simpleRNN
126+
* layer.
124127
*/
125-
function buildSimpleRNNModel(inputShape) {
128+
export function buildSimpleRNNModel(inputShape) {
126129
const model = tf.sequential();
127130
const rnnUnits = 32;
128131
model.add(tf.layers.simpleRNN({
@@ -139,9 +142,9 @@ function buildSimpleRNNModel(inputShape) {
139142
* @param {tf.Shape} inputShape Input shape (without the batch dimenson).
140143
* @param {number} dropout Optional input dropout rate
141144
* @param {number} recurrentDropout Optional recurrent dropout rate.
142-
* @returns {tf.Model} A TensorFlow.js GRU model.
145+
* @returns {tf.LayersModel} A TensorFlow.js GRU model.
143146
*/
144-
function buildGRUModel(inputShape, dropout, recurrentDropout) {
147+
export function buildGRUModel(inputShape, dropout, recurrentDropout) {
145148
// TODO(cais): Recurrent dropout is currently not fully working.
146149
// Make it work and add a flag to train-rnn.js.
147150
const model = tf.sequential();
@@ -163,7 +166,7 @@ function buildGRUModel(inputShape, dropout, recurrentDropout) {
163166
* @param {number} numTimeSteps Number of time steps in each input.
164167
* exapmle
165168
* @param {number} numFeatures Number of features (for each time step).
166-
* @returns A compiled instance of `tf.Model`.
169+
* @returns A compiled instance of `tf.LayersModel`.
167170
*/
168171
export function buildModel(modelType, numTimeSteps, numFeatures) {
169172
const inputShape = [numTimeSteps, numFeatures];
@@ -197,9 +200,9 @@ export function buildModel(modelType, numTimeSteps, numFeatures) {
197200
/**
198201
* Train a model on the Jena weather data.
199202
*
200-
* @param {tf.Model} model A compiled tf.Model object. It is expected to
201-
* have a 3D input shape `[numExamples, timeSteps, numFeatures].` and an
202-
* output shape `[numExamples, 1]` for predicting the temperature value.
203+
* @param {tf.LayersModel} model A compiled tf.LayersModel object. It is
204+
* expected to have a 3D input shape `[numExamples, timeSteps, numFeatures].`
205+
* and an output shape `[numExamples, 1]` for predicting the temperature value.
203206
* @param {JenaWeatherData} jenaWeatherData A JenaWeatherData object.
204207
* @param {boolean} normalize Whether to used normalized data for training.
205208
* @param {boolean} includeDateTime Whether to include date and time features

jena-weather/models_test.js

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs-node';
19+
20+
import {JenaWeatherData} from "./data";
21+
import {buildGRUModel, buildMLPModel, buildSimpleRNNModel, getBaselineMeanAbsoluteError} from "./models";
22+
23+
describe('Model creation', () => {
24+
it('MLP', () => {
25+
const model = buildMLPModel([8, 9]);
26+
const config = model.layers[1].getConfig();
27+
expect(config.kernelRegularizer).toEqual(null);
28+
expect(model.inputs.length).toEqual(1);
29+
expect(model.inputs[0].shape).toEqual([null, 8, 9]);
30+
expect(model.outputs.length).toEqual(1);
31+
expect(model.outputs[0].shape).toEqual([null, 1]);
32+
});
33+
34+
it('MLP with kernel regularizer', () => {
35+
const model = buildMLPModel([8, 9], tf.regularizers.l2({l2: 5e-2}));
36+
const config = model.layers[1].getConfig();
37+
expect(config.kernelRegularizer.config.l2).toEqual(5e-2);
38+
expect(model.inputs.length).toEqual(1);
39+
expect(model.inputs[0].shape).toEqual([null, 8, 9]);
40+
expect(model.outputs.length).toEqual(1);
41+
expect(model.outputs[0].shape).toEqual([null, 1]);
42+
});
43+
44+
it('MLP with dropout', () => {
45+
const model = buildMLPModel([8, 9], null, 0.5);
46+
const denseConfig = model.layers[1].getConfig();
47+
expect(denseConfig.kernelRegularize).toEqual(undefined);
48+
const dropoutConfig = model.layers[model.layers.length - 2].getConfig();
49+
expect(dropoutConfig.rate).toEqual(0.5);
50+
expect(model.inputs.length).toEqual(1);
51+
expect(model.inputs[0].shape).toEqual([null, 8, 9]);
52+
expect(model.outputs.length).toEqual(1);
53+
expect(model.outputs[0].shape).toEqual([null, 1]);
54+
});
55+
});
56+
57+
describe('RNN', () => {
58+
it('simpleRNN', () => {
59+
const model = buildSimpleRNNModel([8, 9]);
60+
expect(model.inputs.length).toEqual(1);
61+
expect(model.inputs[0].shape).toEqual([null, 8, 9]);
62+
expect(model.outputs.length).toEqual(1);
63+
expect(model.outputs[0].shape).toEqual([null, 1]);
64+
});
65+
66+
it('buildGRUModel', () => {
67+
const model = buildGRUModel([8, 9]);
68+
expect(model.inputs.length).toEqual(1);
69+
expect(model.inputs[0].shape).toEqual([null, 8, 9]);
70+
expect(model.outputs.length).toEqual(1);
71+
expect(model.outputs[0].shape).toEqual([null, 1]);
72+
});
73+
});
74+
75+
describe('getBaselineMeanAbsoluteError', () => {
76+
it('getBaselineMeanAbsoluteError', async () => {
77+
const dataset = new JenaWeatherData();
78+
await dataset.load();
79+
80+
const baselineMAE = await getBaselineMeanAbsoluteError(
81+
dataset, true, false, 10 * 24 * 6, 6, 24 * 6);
82+
expect(baselineMAE).toBeCloseTo(0.29033);
83+
});
84+
});

jena-weather/package.json

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^1.0.2",
12+
"@tensorflow/tfjs": "^1.1.0",
1313
"@tensorflow/tfjs-vis": "^1.0.3"
1414
},
1515
"scripts": {
1616
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open",
1717
"build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
1818
"link-local": "yalc link",
19+
"test": "babel-node run_tests.js",
1920
"train-rnn": "babel-node --max_old_space_size=4096 train-rnn.js"
2021
},
2122
"devDependencies": {
22-
"@tensorflow/tfjs-node": "^1.0.2",
23-
"@tensorflow/tfjs-node-gpu": "^1.0.2",
23+
"@tensorflow/tfjs-node": "^1.1.0",
24+
"@tensorflow/tfjs-node-gpu": "^1.1.0",
2425
"argparse": "^1.0.10",
2526
"babel-cli": "^6.26.0",
2627
"babel-core": "^6.26.3",

jena-weather/run_tests.js

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const jasmine_util = require('@tensorflow/tfjs-core/dist/jasmine_util');
19+
const runTests = require('../test_util').runTests;
20+
21+
runTests(jasmine_util, ['./*test.js']);

jena-weather/train-rnn.js

+20-4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ function parseArguments() {
7878
parser.addArgument(
7979
'--epochs',
8080
{type: 'int', defaultValue: 20, help: 'Number of training epochs'});
81+
parser.addArgument( '--earlyStoppingPatience', {
82+
type: 'int',
83+
defaultValue: 2,
84+
help: 'Optional patience number for EarlyStoppingCallback'
85+
});
8186
parser.addArgument('--logDir', {
8287
type: 'string',
8388
help: 'Optional tensorboard log directory, to which the loss and ' +
@@ -121,14 +126,23 @@ async function main() {
121126
const model = buildModel(
122127
args.modelType, Math.floor(args.lookBack / args.step), numFeatures);
123128

124-
let callback = null;
129+
let callback = [];
125130
if (args.logDir != null) {
126131
console.log(
127132
`Logging to tensorboard. ` +
128133
`Use the command below to bring up tensorboard server:\n` +
129134
` tensorboard --logdir ${args.logDir}`);
130-
callback = tfn.node.tensorBoard(
131-
args.logDir, {updateFreq: args.logUpdateFreq});
135+
callback.push(tfn.node.tensorBoard(args.logDir, {
136+
updateFreq: args.logUpdateFreq
137+
}));
138+
}
139+
if (args.earlyStoppingPatience != null) {
140+
console.log(
141+
`Using earlyStoppingCallback with patience ` +
142+
`${args.earlyStoppingPatience}.`);
143+
callback.push(tfn.callbacks.earlyStopping({
144+
patience: args.earlyStoppingPatience
145+
}));
132146
}
133147

134148
await trainModel(
@@ -138,4 +152,6 @@ async function main() {
138152
}
139153
}
140154

141-
main();
155+
if (require.main === module) {
156+
main();
157+
}

0 commit comments

Comments
 (0)