-
Notifications
You must be signed in to change notification settings - Fork 7
/
index.js
78 lines (64 loc) · 1.85 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/**
* Node gRPC client for TensorFlow Serving server.
*/
const grpc = require('grpc');
const PROTO_PATH = __dirname + '/protos/prediction_service.proto';
/**
* Exports client.
*
* @param {string} connection Tensorflow Serving connection string. E.g., localhost:9000
*/
module.exports = (connection) => {
// loading service proto
var tensorflow_serving = grpc.load(PROTO_PATH).tensorflow.serving;
// creating gRPC service client
var client = new tensorflow_serving.PredictionService(
connection, grpc.credentials.createInsecure()
);
return {
/**
* Calls predict gRPC method on TensorFlow Serving server.
*
* @param {Buffer|Array<Buffer>} buffer JPEG data buffer to classify.
* @param {Function} fn Callback.
*/
predict: (buffer, fn) => {
var buffers;
if (buffer.constructor === Array) {
buffers = buffer;
} else {
buffers = [buffer];
}
// building PredictRequest proto message
const msg = {
model_spec: { name: 'inception', signature_name: 'predict_images' },
inputs: {
images: {
dtype: 'DT_STRING',
tensor_shape: {
dim: {
size: buffers.length
}
},
string_val: buffers
}
}
};
// calling gRPC method
client.predict(msg, (err, response) => {
if (err) return fn(err);
// decoding response
const classes = response.outputs.classes.string_val.map((b) => b.toString('utf8'));
// splitting results
var i,
len = classes.length,
chunk = 5,
results = [];
for (i = 0; i < len; i+=chunk) {
results.push(classes.slice(i, i+chunk));
}
fn(null, results)
});
}
}
};