16
16
*/
17
17
18
18
/**
19
- * Creating and training `tf.Model`s for the temperature prediction problem.
19
+ * Creating and training `tf.LayersModel`s for the temperature prediction
20
+ * problem.
20
21
*
21
22
* This file is used to create models for both
22
23
* - the browser: see [index.js](./index.js), and
@@ -51,23 +52,24 @@ const VAL_MAX_ROW = 300000;
51
52
export async function getBaselineMeanAbsoluteError (
52
53
jenaWeatherData , normalize , includeDateTime , lookBack , step , delay ) {
53
54
const batchSize = 128 ;
54
- const nextBatchFn = jenaWeatherData . getNextBatchFunction (
55
- false , lookBack , delay , batchSize , step , VAL_MIN_ROW , VAL_MAX_ROW ,
56
- normalize , includeDateTime ) ;
57
- const dataset = tf . data . generator ( nextBatchFn ) ;
55
+ const dataset = tf . data . generator (
56
+ ( ) => jenaWeatherData . getNextBatchFunction (
57
+ false , lookBack , delay , batchSize , step , VAL_MIN_ROW ,
58
+ VAL_MAX_ROW , normalize , includeDateTime ) ) ;
58
59
59
60
const batchMeanAbsoluteErrors = [ ] ;
60
61
const batchSizes = [ ] ;
61
62
await dataset . forEach ( dataItem => {
62
- const features = dataItem [ 0 ] ;
63
- const targets = dataItem [ 1 ] ;
63
+ const features = dataItem . xs ;
64
+ const targets = dataItem . ys ;
64
65
const timeSteps = features . shape [ 1 ] ;
65
66
batchSizes . push ( features . shape [ 0 ] ) ;
66
67
batchMeanAbsoluteErrors . push ( tf . tidy (
67
68
( ) => tf . losses . absoluteDifference (
68
69
targets ,
69
70
features . gather ( [ timeSteps - 1 ] , 1 ) . gather ( [ 1 ] , 2 ) . squeeze ( [ 2 ] ) ) ) ) ;
70
71
} ) ;
72
+
71
73
const meanAbsoluteError = tf . tidy ( ( ) => {
72
74
const batchSizesTensor = tf . tensor1d ( batchSizes ) ;
73
75
const batchMeanAbsoluteErrorsTensor = tf . stack ( batchMeanAbsoluteErrors ) ;
@@ -83,7 +85,7 @@ export async function getBaselineMeanAbsoluteError(
83
85
* Build a linear-regression model for the temperature-prediction problem.
84
86
*
85
87
* @param {tf.Shape } inputShape Input shape (without the batch dimenson).
86
- * @returns {tf.Model } A TensorFlow.js tf.Model instance.
88
+ * @returns {tf.LayersModel } A TensorFlow.js tf.LayersModel instance.
87
89
*/
88
90
function buildLinearRegressionModel ( inputShape ) {
89
91
const model = tf . sequential ( ) ;
@@ -102,9 +104,9 @@ function buildLinearRegressionModel(inputShape) {
102
104
* @param {number } dropoutRate Dropout rate of an optional dropout layer
103
105
* inserted between the two dense layers of the MLP. Optional. If not
104
106
* specified, no dropout layers will be included in the MLP.
105
- * @returns {tf.Model } A TensorFlow.js tf.Model instance.
107
+ * @returns {tf.LayersModel } A TensorFlow.js tf.LayersModel instance.
106
108
*/
107
- function buildMLPModel ( inputShape , kernelRegularizer , dropoutRate ) {
109
+ export function buildMLPModel ( inputShape , kernelRegularizer , dropoutRate ) {
108
110
const model = tf . sequential ( ) ;
109
111
model . add ( tf . layers . flatten ( { inputShape} ) ) ;
110
112
model . add (
@@ -120,9 +122,10 @@ function buildMLPModel(inputShape, kernelRegularizer, dropoutRate) {
120
122
* Build a simpleRNN-based model for the temperature-prediction problem.
121
123
*
122
124
* @param {tf.Shape } inputShape Input shape (without the batch dimenson).
123
- * @returns {tf.Model } A TensorFlow.js model consisting of a simpleRNN layer.
125
+ * @returns {tf.LayersModel } A TensorFlow.js model consisting of a simpleRNN
126
+ * layer.
124
127
*/
125
- function buildSimpleRNNModel ( inputShape ) {
128
+ export function buildSimpleRNNModel ( inputShape ) {
126
129
const model = tf . sequential ( ) ;
127
130
const rnnUnits = 32 ;
128
131
model . add ( tf . layers . simpleRNN ( {
@@ -139,9 +142,9 @@ function buildSimpleRNNModel(inputShape) {
139
142
* @param {tf.Shape } inputShape Input shape (without the batch dimenson).
140
143
* @param {number } dropout Optional input dropout rate
141
144
* @param {number } recurrentDropout Optional recurrent dropout rate.
142
- * @returns {tf.Model } A TensorFlow.js GRU model.
145
+ * @returns {tf.LayersModel } A TensorFlow.js GRU model.
143
146
*/
144
- function buildGRUModel ( inputShape , dropout , recurrentDropout ) {
147
+ export function buildGRUModel ( inputShape , dropout , recurrentDropout ) {
145
148
// TODO(cais): Recurrent dropout is currently not fully working.
146
149
// Make it work and add a flag to train-rnn.js.
147
150
const model = tf . sequential ( ) ;
@@ -163,7 +166,7 @@ function buildGRUModel(inputShape, dropout, recurrentDropout) {
163
166
* @param {number } numTimeSteps Number of time steps in each input.
164
167
* exapmle
165
168
* @param {number } numFeatures Number of features (for each time step).
166
- * @returns A compiled instance of `tf.Model `.
169
+ * @returns A compiled instance of `tf.LayersModel `.
167
170
*/
168
171
export function buildModel ( modelType , numTimeSteps , numFeatures ) {
169
172
const inputShape = [ numTimeSteps , numFeatures ] ;
@@ -197,9 +200,9 @@ export function buildModel(modelType, numTimeSteps, numFeatures) {
197
200
/**
198
201
* Train a model on the Jena weather data.
199
202
*
200
- * @param {tf.Model } model A compiled tf.Model object. It is expected to
201
- * have a 3D input shape `[numExamples, timeSteps, numFeatures].` and an
202
- * output shape `[numExamples, 1]` for predicting the temperature value.
203
+ * @param {tf.LayersModel } model A compiled tf.LayersModel object. It is
204
+ * expected to have a 3D input shape `[numExamples, timeSteps, numFeatures].`
205
+ * and an output shape `[numExamples, 1]` for predicting the temperature value.
203
206
* @param {JenaWeatherData } jenaWeatherData A JenaWeatherData object.
204
207
* @param {boolean } normalize Whether to used normalized data for training.
205
208
* @param {boolean } includeDateTime Whether to include date and time features
0 commit comments