Skip to content

Commit 471429d

Browse files
author
Nikhil Thorat
authored
Add linting rules for tfjs-models. (#333)
- Adds lint rules to mirror monorepo - Fixes lint errors - Adds CI to run lint - Refactors body pix and posenet to simplify base models for resnet / mobilenet.
1 parent d101c73 commit 471429d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1232
-1542
lines changed

Diff for: .tslint/noImportsFromDistRule.js

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
'use strict';
2+
var __extends = (this && this.__extends) || (function() {
3+
var extendStatics = function(d, b) {
4+
extendStatics = Object.setPrototypeOf ||
5+
({__proto__: []} instanceof Array && function(d, b) {
6+
d.__proto__ = b;
7+
}) || function(d, b) {
8+
for (var p in b)
9+
if (b.hasOwnProperty(p)) d[p] = b[p];
10+
};
11+
return extendStatics(d, b);
12+
};
13+
return function(d, b) {
14+
extendStatics(d, b);
15+
function __() {
16+
this.constructor = d;
17+
}
18+
d.prototype = b === null ?
19+
Object.create(b) :
20+
(__.prototype = b.prototype, new __());
21+
};
22+
})();
23+
exports.__esModule = true;
24+
var Lint = require('tslint');
25+
var Rule = /** @class */ (function(_super) {
26+
__extends(Rule, _super);
27+
function Rule() {
28+
return _super !== null && _super.apply(this, arguments) || this;
29+
}
30+
Rule.prototype.apply = function(sourceFile) {
31+
return this.applyWithWalker(
32+
new NoImportsFromDistWalker(sourceFile, this.getOptions()));
33+
};
34+
Rule.FAILURE_STRING =
35+
'importing from dist/ is prohibited. Please use public API';
36+
return Rule;
37+
}(Lint.Rules.AbstractRule));
38+
exports.Rule = Rule;
39+
var NoImportsFromDistWalker = /** @class */ (function(_super) {
40+
__extends(NoImportsFromDistWalker, _super);
41+
function NoImportsFromDistWalker() {
42+
return _super !== null && _super.apply(this, arguments) || this;
43+
}
44+
NoImportsFromDistWalker.prototype.visitImportDeclaration = function(node) {
45+
var importFrom = node.moduleSpecifier.getText();
46+
var reg = /@tensorflow\/tfjs[-a-z]*\/dist/;
47+
if (importFrom.match(reg)) {
48+
var fix = new Lint.Replacement(
49+
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
50+
importFrom.replace(/\/dist[\/]*/, ''));
51+
this.addFailure(this.createFailure(
52+
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
53+
Rule.FAILURE_STRING, fix));
54+
}
55+
_super.prototype.visitImportDeclaration.call(this, node);
56+
};
57+
return NoImportsFromDistWalker;
58+
}(Lint.RuleWalker));

Diff for: .tslint/noImportsFromDistRule.ts

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import * as Lint from 'tslint';
2+
import * as ts from 'typescript';
3+
4+
export class Rule extends Lint.Rules.AbstractRule {
5+
public static FAILURE_STRING =
6+
'importing from dist/ is prohibited. Please use public API';
7+
8+
public apply(sourceFile: ts.SourceFile): Lint.RuleFailure[] {
9+
return this.applyWithWalker(
10+
new NoImportsFromDistWalker(sourceFile, this.getOptions()));
11+
}
12+
}
13+
14+
class NoImportsFromDistWalker extends Lint.RuleWalker {
15+
public visitImportDeclaration(node: ts.ImportDeclaration) {
16+
const importFrom = node.moduleSpecifier.getText();
17+
const reg = /@tensorflow\/tfjs[-a-z]*\/dist/;
18+
if (importFrom.match(reg)) {
19+
const fix = new Lint.Replacement(
20+
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
21+
importFrom.replace(/\/dist[\/]*/, ''));
22+
23+
this.addFailure(this.createFailure(
24+
node.moduleSpecifier.getStart(), node.moduleSpecifier.getWidth(),
25+
Rule.FAILURE_STRING, fix));
26+
}
27+
28+
super.visitImportDeclaration(node);
29+
}
30+
}

Diff for: body-pix/cloudbuild.yml

+8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ steps:
1414
args: ['install']
1515
waitFor: ['yarn-common']
1616

17+
# Lint.
18+
- name: 'node:10'
19+
dir: 'body-pix'
20+
entrypoint: 'yarn'
21+
id: 'lint'
22+
args: ['lint']
23+
waitFor: ['yarn']
24+
1725
# Build.
1826
- name: 'node:10'
1927
dir: 'body-pix'

Diff for: body-pix/demos/package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
"scripts": {
1919
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open ",
2020
"build": "cross-env NODE_ENV=production parcel build index.html --public-url ./",
21-
"lint": "eslint ."
21+
"lint": "eslint .",
22+
"link-local": "yalc link"
2223
},
2324
"browser": {
2425
"crypto": false

Diff for: body-pix/demos/yarn.lock

-5
Original file line numberDiff line numberDiff line change
@@ -6243,11 +6243,6 @@ typedarray@^0.0.6:
62436243
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777"
62446244
integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=
62456245

6246-
6247-
version "0.0.54"
6248-
resolved "https://registry.yarnpkg.com/typeface-oswald/-/typeface-oswald-0.0.54.tgz#1e253011622cdd50f580c04e7d625e7f449763d7"
6249-
integrity sha512-U1WMNp4qfy4/3khIfHMVAIKnNu941MXUfs3+H9R8PFgnoz42Hh9pboSFztWr86zut0eXC8byalmVhfkiKON/8Q==
6250-
62516246
uncss@^0.17.0:
62526247
version "0.17.2"
62536248
resolved "https://registry.yarnpkg.com/uncss/-/uncss-0.17.2.tgz#fac1c2429be72108e8a47437c647d58cf9ea66f1"

Diff for: body-pix/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"rollup-plugin-typescript2": "~0.13.0",
3535
"rollup-plugin-uglify": "~3.0.0",
3636
"ts-node": "~5.0.0",
37-
"tslint": "~5.8.0",
37+
"tslint": "~5.18.0",
3838
"typescript": "~3.5.3",
3939
"yalc": "^1.0.0-pre.27"
4040
},

Diff for: body-pix/src/base_model.ts

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
2+
/**
3+
* @license
4+
* Copyright 2019 Google Inc. All Rights Reserved.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
* =============================================================================
17+
*/
18+
19+
import * as tfconv from '@tensorflow/tfjs-converter';
20+
import * as tf from '@tensorflow/tfjs-core';
21+
import {BodyPixOutputStride} from './types';
22+
23+
/**
24+
* BodyPix supports using various convolution neural network models
25+
* (e.g. ResNet and MobileNetV1) as its underlying base model.
26+
* The following BaseModel interface defines a unified interface for
27+
* creating such BodyPix base models. Currently both MobileNet (in
28+
* ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel
29+
* interface. New base models that conform to the BaseModel interface can be
30+
* added to BodyPix.
31+
*/
32+
export abstract class BaseModel {
33+
constructor(
34+
protected readonly model: tfconv.GraphModel,
35+
public readonly outputStride: BodyPixOutputStride) {
36+
const inputShape =
37+
this.model.inputs[0].shape as [number, number, number, number];
38+
tf.util.assert(
39+
(inputShape[1] === -1) && (inputShape[2] === -1),
40+
() => `Input shape [${inputShape[1]}, ${inputShape[2]}] ` +
41+
`must both be equal to or -1`);
42+
}
43+
44+
abstract preprocessInput(input: tf.Tensor3D): tf.Tensor3D;
45+
46+
/**
47+
* Predicts intermediate Tensor representations.
48+
*
49+
* @param input The input RGB image of the base model.
50+
* A Tensor of shape: [`inputResolution`, `inputResolution`, 3].
51+
*
52+
* @return A dictionary of base model's intermediate predictions.
53+
* The returned dictionary should contains the following elements:
54+
* - heatmapScores: A Tensor3D that represents the keypoint heatmap scores.
55+
* - offsets: A Tensor3D that represents the offsets.
56+
* - displacementFwd: A Tensor3D that represents the forward displacement.
57+
* - displacementBwd: A Tensor3D that represents the backward displacement.
58+
* - segmentation: A Tensor3D that represents the segmentation of all
59+
* people.
60+
* - longOffsets: A Tensor3D that represents the long offsets used for
61+
* instance grouping.
62+
* - partHeatmaps: A Tensor3D that represents the body part segmentation.
63+
*/
64+
predict(input: tf.Tensor3D): {
65+
heatmapScores: tf.Tensor3D,
66+
offsets: tf.Tensor3D,
67+
displacementFwd: tf.Tensor3D,
68+
displacementBwd: tf.Tensor3D,
69+
segmentation: tf.Tensor3D,
70+
partHeatmaps: tf.Tensor3D,
71+
longOffsets: tf.Tensor3D,
72+
partOffsets: tf.Tensor3D
73+
} {
74+
return tf.tidy(() => {
75+
const asFloat = this.preprocessInput(input.toFloat());
76+
const asBatch = asFloat.expandDims(0);
77+
const results = this.model.predict(asBatch) as tf.Tensor4D[];
78+
const results3d: tf.Tensor3D[] = results.map(y => y.squeeze([0]));
79+
const namedResults = this.nameOutputResults(results3d);
80+
81+
return {
82+
heatmapScores: namedResults.heatmap.sigmoid(),
83+
offsets: namedResults.offsets,
84+
displacementFwd: namedResults.displacementFwd,
85+
displacementBwd: namedResults.displacementBwd,
86+
segmentation: namedResults.segmentation,
87+
partHeatmaps: namedResults.partHeatmaps,
88+
longOffsets: namedResults.longOffsets,
89+
partOffsets: namedResults.partOffsets
90+
};
91+
});
92+
}
93+
94+
// Because MobileNet and ResNet predict() methods output a different order for
95+
// these values, we have a method that needs to be implemented to order them.
96+
abstract nameOutputResults(results: tf.Tensor3D[]): {
97+
heatmap: tf.Tensor3D,
98+
offsets: tf.Tensor3D,
99+
displacementFwd: tf.Tensor3D,
100+
displacementBwd: tf.Tensor3D,
101+
segmentation: tf.Tensor3D,
102+
partHeatmaps: tf.Tensor3D,
103+
longOffsets: tf.Tensor3D,
104+
partOffsets: tf.Tensor3D
105+
};
106+
107+
/**
108+
* Releases the CPU and GPU memory allocated by the model.
109+
*/
110+
dispose() {
111+
this.model.dispose();
112+
}
113+
}

Diff for: body-pix/src/body_pix_model.ts

+14-51
Original file line numberDiff line numberDiff line change
@@ -19,56 +19,19 @@
1919
import * as tfconv from '@tensorflow/tfjs-converter';
2020
import * as tf from '@tensorflow/tfjs-core';
2121

22+
import {BaseModel} from './base_model';
2223
import {decodeOnlyPartSegmentation, decodePartSegmentation, toMaskTensor} from './decode_part_map';
23-
import {MobileNet, MobileNetMultiplier} from './mobilenet';
24+
import {MobileNet} from './mobilenet';
2425
import {decodePersonInstanceMasks, decodePersonInstancePartMasks} from './multi_person/decode_instance_masks';
2526
import {decodeMultiplePoses} from './multi_person/decode_multiple_poses';
2627
import {ResNet} from './resnet';
2728
import {mobileNetSavedModel, resNet50SavedModel} from './saved_models';
28-
import {decodeSinglePose} from './sinlge_person/decode_single_pose';
29+
import {decodeSinglePose} from './single_person/decode_single_pose';
2930
import {BodyPixArchitecture, BodyPixInput, BodyPixInternalResolution, BodyPixMultiplier, BodyPixOutputStride, BodyPixQuantBytes, Padding, PartSegmentation, PersonSegmentation} from './types';
3031
import {getInputSize, padAndResizeTo, scaleAndCropToInputTensorShape, scaleAndFlipPoses, toTensorBuffers3D, toValidInternalResolutionNumber} from './util';
3132

32-
3333
const APPLY_SIGMOID_ACTIVATION = true;
3434

35-
/**
36-
* BodyPix supports using various convolution neural network models
37-
* (e.g. ResNet and MobileNetV1) as its underlying base model.
38-
* The following BaseModel interface defines a unified interface for
39-
* creating such BodyPix base models. Currently both MobileNet (in
40-
* ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel
41-
* interface. New base models that conform to the BaseModel interface can be
42-
* added to BodyPix.
43-
*/
44-
export interface BaseModel {
45-
// The output stride of the base model.
46-
readonly outputStride: BodyPixOutputStride;
47-
48-
/**
49-
* Predicts intermediate Tensor representations.
50-
*
51-
* @param input The input RGB image of the base model.
52-
* A Tensor of shape: [`inputResolution`, `inputResolution`, 3].
53-
*
54-
* @return A dictionary of base model's intermediate predictions.
55-
* The returned dictionary should contains the following elements:
56-
* - heatmapScores: A Tensor3D that represents the keypoint heatmap scores.
57-
* - offsets: A Tensor3D that represents the offsets.
58-
* - displacementFwd: A Tensor3D that represents the forward displacement.
59-
* - displacementBwd: A Tensor3D that represents the backward displacement.
60-
* - segmentation: A Tensor3D that represents the segmentation of all people.
61-
* - longOffsets: A Tensor3D that represents the long offsets used for
62-
* instance grouping.
63-
* - partHeatmaps: A Tensor3D that represents the body part segmentation.
64-
*/
65-
predict(input: tf.Tensor3D): {[key: string]: tf.Tensor3D};
66-
/**
67-
* Releases the CPU and GPU memory allocated by the model.
68-
*/
69-
dispose(): void;
70-
}
71-
7235
/**
7336
* BodyPix model loading is configurable using the following config dictionary.
7437
*
@@ -101,7 +64,7 @@ export interface BaseModel {
10164
export interface ModelConfig {
10265
architecture: BodyPixArchitecture;
10366
outputStride: BodyPixOutputStride;
104-
multiplier?: MobileNetMultiplier;
67+
multiplier?: BodyPixMultiplier;
10568
modelUrl?: string;
10669
quantBytes?: BodyPixQuantBytes;
10770
}
@@ -602,15 +565,15 @@ export class BodyPix {
602565
};
603566
});
604567

605-
const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] =
568+
const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] =
606569
await toTensorBuffers3D([
607570
heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
608571
]);
609572

610-
let poses = await decodeMultiplePoses(
611-
scoresBuffer, offsetsBuffer, displacementsFwdBuffer,
612-
displacementsBwdBuffer, this.baseModel.outputStride,
613-
config.maxDetections, config.scoreThreshold, config.nmsRadius);
573+
let poses = decodeMultiplePoses(
574+
scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf,
575+
this.baseModel.outputStride, config.maxDetections,
576+
config.scoreThreshold, config.nmsRadius);
614577

615578
poses = scaleAndFlipPoses(
616579
poses, [height, width],
@@ -849,15 +812,15 @@ export class BodyPix {
849812
};
850813
});
851814

852-
const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] =
815+
const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] =
853816
await toTensorBuffers3D([
854817
heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
855818
]);
856819

857-
let poses = await decodeMultiplePoses(
858-
scoresBuffer, offsetsBuffer, displacementsFwdBuffer,
859-
displacementsBwdBuffer, this.baseModel.outputStride,
860-
config.maxDetections, config.scoreThreshold, config.nmsRadius);
820+
let poses = decodeMultiplePoses(
821+
scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf,
822+
this.baseModel.outputStride, config.maxDetections,
823+
config.scoreThreshold, config.nmsRadius);
861824

862825
poses = scaleAndFlipPoses(
863826
poses, [height, width],

0 commit comments

Comments
 (0)