Skip to content

Commit 00ab7f6

Browse files
authored Mar 15, 2019
[sentiment] Add unit tests (tensorflow#250)
1 parent 0c33845 commit 00ab7f6

11 files changed

+384
-6
lines changed
 

‎sentiment/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,13 @@ These files can be directly uploaded to the Embedding Projector
9999

100100
See example screenshot:
101101
![image](https://user-images.githubusercontent.com/16824702/52145038-f0fce480-262d-11e9-9313-9a5014ace25f.png)
102+
103+
### Running unit tests
104+
105+
This example comes with unit tests. If you would like to submit changes to the code,
106+
be sure to run the tests and ensure they pass first:
107+
108+
```sh
109+
yarn
110+
yarn test
111+
```

‎sentiment/data.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import * as https from 'https';
2121
import * as os from 'os';
2222
import * as path from 'path';
2323

24-
import {OOV_INDEX, PAD_INDEX, padSequences} from './sequence_utils';
24+
import {OOV_INDEX, padSequences} from './sequence_utils';
2525

2626
// `import` doesn't seem to work with extract-zip.
2727
const extract = require('extract-zip');
@@ -92,7 +92,7 @@ function loadFeatures(filePath, numWords, maxLen, multihot = false) {
9292
const buffer = tf.buffer([sequences.length, numWords]);
9393
sequences.forEach((seq, i) => {
9494
seq.forEach(wordIndex => {
95-
if (wordIndex !== OOV_CHAR) {
95+
if (wordIndex !== OOV_INDEX) {
9696
buffer.set(1, i, wordIndex);
9797
}
9898
});

‎sentiment/data_test.js

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {loadData} from "./data";
19+
20+
describe('loadData', () => {
21+
it('multihot = false', async () => {
22+
const numWords = 10;
23+
const len = 5;
24+
const multihot = false;
25+
const data = await loadData(numWords, len, multihot);
26+
27+
expect(data.xTrain.shape.length).toEqual(2);
28+
expect(data.yTrain.shape.length).toEqual(2);
29+
expect(data.xTrain.shape[0]).toEqual(data.yTrain.shape[0]);
30+
expect(data.xTrain.shape[1]).toEqual(5);
31+
expect(data.yTrain.shape[1]).toEqual(1);
32+
expect(data.xTest.shape.length).toEqual(2);
33+
expect(data.yTest.shape.length).toEqual(2);
34+
expect(data.xTest.shape[0]).toEqual(data.yTest.shape[0]);
35+
expect(data.xTest.shape[1]).toEqual(5);
36+
expect(data.yTest.shape[1]).toEqual(1);
37+
});
38+
39+
it('multihot = true', async () => {
40+
const numWords = 10;
41+
const len = 5;
42+
const multihot = true;
43+
const data = await loadData(numWords, len, multihot);
44+
45+
expect(data.xTrain.shape.length).toEqual(2);
46+
expect(data.yTrain.shape.length).toEqual(2);
47+
expect(data.xTrain.shape[0]).toEqual(data.yTrain.shape[0]);
48+
expect(data.xTrain.shape[1]).toEqual(10);
49+
expect(data.yTrain.shape[1]).toEqual(1);
50+
expect(data.xTest.shape.length).toEqual(2);
51+
expect(data.yTest.shape.length).toEqual(2);
52+
expect(data.xTest.shape[0]).toEqual(data.yTest.shape[0]);
53+
expect(data.xTest.shape[1]).toEqual(10);
54+
expect(data.yTest.shape[1]).toEqual(1);
55+
});
56+
});
57+

‎sentiment/embedding_test.js

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as fs from 'fs';
19+
20+
import * as tf from '@tensorflow/tfjs-node';
21+
import * as shelljs from 'shelljs';
22+
import * as tmp from 'tmp';
23+
24+
import {buildModel} from "./train";
25+
import {writeEmbeddingMatrixAndLabels} from "./embedding";
26+
27+
describe('writeEmbeddingMatrixAndLabels', () => {
28+
let tmpDir;
29+
30+
beforeEach(() => {
31+
tmpDir = tmp.dirSync().name;
32+
});
33+
34+
afterEach(() => {
35+
if (fs.existsSync(tmpDir)) {
36+
shelljs.rm('-rf', tmpDir);
37+
}
38+
});
39+
40+
it('writeEmbeddingMatrixAndLabels', async () => {
41+
const maxLen = 5;
42+
const vocabSize = 4;
43+
const embeddingSize = 8;
44+
const model = buildModel('cnn', maxLen, vocabSize, embeddingSize);
45+
expect(model.inputs.length).toEqual(1);
46+
expect(model.inputs[0].shape).toEqual([null, maxLen]);
47+
expect(model.outputs.length).toEqual(1);
48+
expect(model.outputs[0].shape).toEqual([null, 1]);
49+
50+
model.compile({
51+
loss: 'binaryCrossentropy',
52+
optimizer: 'rmsprop',
53+
metrics: ['acc']
54+
});
55+
const xs = tf.ones([2, maxLen])
56+
const ys = tf.ones([2, 1]);
57+
const history = await model.fit(xs, ys, {epochs: 2, batchSize: 2});
58+
expect(history.history.loss.length).toEqual(2);
59+
expect(history.history.acc.length).toEqual(2);
60+
61+
const predictOuts = model.predict(xs);
62+
expect(predictOuts.shape).toEqual([2, 1]);
63+
const values = predictOuts.arraySync();
64+
expect(values[0][0]).toBeGreaterThanOrEqual(0);
65+
expect(values[0][0]).toBeLessThanOrEqual(1);
66+
expect(values[1][0]).toBeGreaterThanOrEqual(0);
67+
expect(values[1][0]).toBeLessThanOrEqual(1);
68+
69+
const wordIndex = {
70+
'foo': 1,
71+
'bar': 2,
72+
'baz': 3,
73+
'qux': 4
74+
};
75+
await writeEmbeddingMatrixAndLabels(model, `${tmpDir}/embed`, wordIndex, 0);
76+
const vectorFileContent =
77+
fs.readFileSync(`${tmpDir}/embed_vectors.tsv`, {encoding: 'utf-8'})
78+
.trim().split('\n');
79+
expect(vectorFileContent.length).toEqual(4);
80+
const labelsFileContent =
81+
fs.readFileSync(`${tmpDir}/embed_labels.tsv`, {encoding: 'utf-8'})
82+
.trim().split('\n');
83+
expect(labelsFileContent.length).toEqual(4);
84+
});
85+
});

‎sentiment/package.json

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"watch": "./serve.sh",
1717
"build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
1818
"link-local": "yalc link",
19+
"test": "babel-node run_tests.js",
1920
"train": "babel-node train.js"
2021
},
2122
"devDependencies": {
@@ -33,6 +34,7 @@
3334
"http-server": "~0.10.0",
3435
"parcel-bundler": "~1.10.3",
3536
"shelljs": "^0.8.3",
37+
"tmp": "^0.0.33",
3638
"yalc": "~1.0.0-pre.22"
3739
}
3840
}

‎sentiment/run_tests.js

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const jasmine_util = require('@tensorflow/tfjs-core/dist/jasmine_util');
19+
const runTests = require('../test_util').runTests;
20+
21+
runTests(jasmine_util, ['./*test.js']);

‎sentiment/sequence_utils_test.js

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {padSequences} from "./sequence_utils";
19+
20+
describe('padSequences', () => {
21+
it('post pad, post truncate', () => {
22+
const sequences = [[10, 20, 30], [5, 15], [], [1, 2, 3, 4, 5, 6]];
23+
const output = padSequences(sequences, 4, 'post', 'post');
24+
expect(output).toEqual(
25+
[[10, 20, 30, 0], [5, 15, 0, 0], [0, 0, 0, 0], [1, 2, 3, 4]]);
26+
});
27+
28+
it('post pad, pre truncate', () => {
29+
const sequences = [[10, 20, 30], [5, 15], [], [1, 2, 3, 4, 5, 6]];
30+
const output = padSequences(sequences, 4, 'post', 'pre');
31+
expect(output).toEqual(
32+
[[10, 20, 30, 0], [5, 15, 0, 0], [0, 0, 0, 0], [3, 4, 5, 6]]);
33+
});
34+
35+
it('pre pad, post truncate', () => {
36+
const sequences = [[10, 20, 30], [5, 15], [], [1, 2, 3, 4, 5, 6]];
37+
const output = padSequences(sequences, 4, 'pre', 'post');
38+
expect(output).toEqual(
39+
[[0, 10, 20, 30], [0, 0, 5, 15], [0, 0, 0, 0], [1, 2, 3, 4]]);
40+
});
41+
42+
it('pre pad, pre truncate', () => {
43+
const sequences = [[10, 20, 30], [5, 15], [], [1, 2, 3, 4, 5, 6]];
44+
const output = padSequences(sequences, 4, 'pre', 'pre');
45+
expect(output).toEqual(
46+
[[0, 10, 20, 30], [0, 0, 5, 15], [0, 0, 0, 0], [3, 4, 5, 6]]);
47+
});
48+
49+
it('custom padding character', () => {
50+
const sequences = [[10, 20, 30], [5, 15], [], [1, 2, 3, 4, 5, 6]];
51+
const output = padSequences(sequences, 4, 'pre', 'pre', 42);
52+
expect(output).toEqual(
53+
[[42, 10, 20, 30], [42, 42, 5, 15], [42, 42, 42, 42], [3, 4, 5, 6]]);
54+
});
55+
});

‎sentiment/serve.sh

-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,3 @@ node_modules/parcel-bundler/bin/cli.js serve index.html -d dist --open --no-hmr
3939

4040
# When the Parcel server exits, kill the http-server too.
4141
kill $HTTP_SERVER_PID
42-

‎sentiment/train.js

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import {writeEmbeddingMatrixAndLabels} from './embedding';
3333
* configure the embedding layer.
3434
* @returns An uncompiled instance of `tf.Model`.
3535
*/
36-
function buildModel(modelType, maxLen, vocabularySize, embeddingSize) {
36+
export function buildModel(modelType, maxLen, vocabularySize, embeddingSize) {
3737
// TODO(cais): Bidirectional and dense-only.
3838
const model = tf.sequential();
3939
if (modelType === 'multihot') {
@@ -227,4 +227,6 @@ async function main() {
227227
}
228228
}
229229

230-
main();
230+
if (require.main === module) {
231+
main();
232+
}

0 commit comments

Comments
 (0)