-
Notifications
You must be signed in to change notification settings - Fork 232
/
Copy pathgluon_base_service.py
138 lines (116 loc) · 4.69 KB
/
gluon_base_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Gluon Base service defines a Gluon base service for generic CNN
"""
import mxnet as mx
import numpy as np
import os
import json
import ndarray
class GluonBaseService(object):
    """GluonBaseService defines a fundamental service for image classification tasks.

    In preprocess, the input image buffer is decoded to an NDArray and resized
    with respect to the input shape given in the signature.
    In postprocess, the top-5 labels are returned.

    Subclasses must assign ``self.net`` (a Gluon network) before ``initialize``
    is called, and may set ``self.param_filename`` to have trained parameters
    loaded from the model directory.
    """

    def __init__(self):
        # Name of a parameter file inside model_dir; None means "do not load".
        self.param_filename = None
        self.model_name = None
        self.initialized = False
        self.ctx = None
        # Gluon network; must be provided by a subclass before initialize().
        self.net = None
        # Not read anywhere in this class; kept for compatibility with subclasses.
        self._signature = None
        self.labels = None
        self.signature = None

    def initialize(self, params):
        """
        Initialization of the network.

        Picks the execution context, optionally loads network parameters, and
        reads ``signature.json`` and ``synset.txt`` from the model directory.

        :param params: This is the :func `Context` object. Its
            ``system_properties`` provide ``model_dir`` and optionally
            ``gpu_id``; its ``manifest`` must contain ``["model"]["modelName"]``.
        :return: None
        :raises NotImplementedError: if ``self.net`` has not been set.
        :raises OSError: if the parameter, signature or synset file is missing.
        """
        if self.net is None:
            raise NotImplementedError("Gluon network not defined")

        sys_prop = params.system_properties
        gpu_id = sys_prop.get("gpu_id")
        model_dir = sys_prop.get("model_dir")
        self.model_name = params.manifest["model"]["modelName"]
        # Run on CPU unless a GPU id was supplied.
        self.ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id)

        if self.param_filename is not None:
            param_file_path = os.path.join(model_dir, self.param_filename)
            if not os.path.isfile(param_file_path):
                raise OSError("Parameter file not found {}".format(param_file_path))
            self.net.load_parameters(param_file_path, self.ctx)

        synset_file = os.path.join(model_dir, "synset.txt")
        signature_file_path = os.path.join(model_dir, "signature.json")
        if not os.path.isfile(signature_file_path):
            raise OSError("Signature file not found {}".format(signature_file_path))
        if not os.path.isfile(synset_file):
            raise OSError("synset file not available {}".format(synset_file))

        with open(signature_file_path) as sig_file:
            self.signature = json.load(sig_file)
        # Use a context manager so the synset file handle is closed promptly
        # instead of lingering until garbage collection.
        with open(synset_file) as synset:
            self.labels = [line.strip() for line in synset]
        self.initialized = True

    def preprocess(self, data):
        """
        Decode and normalize the first request's image into a model input batch.

        This method considers only one input: the first element of ``data``.

        :param data: list of dict-like objects; the first one must map the
            signature's input name (``signature['inputs'][0]['data_name']``)
            to the encoded image buffer.
        :return: NDArray batch of shape (1, 3, h, w), where h and w are taken
            from the signature's ``data_shape`` (assumed to be NCHW).
        :raises IOError: if ``data`` is empty or its first entry does not
            contain the expected input name.
        """
        input_spec = self.signature['inputs'][0]
        param_name = input_spec['data_name']
        input_shape = input_spec['data_shape']

        # Treat an empty request the same as a missing parameter.
        img = data[0].get(param_name) if data else None
        if img is None:
            raise IOError("Invalid parameter given")

        # We are assuming input shape is NCHW
        [h, w] = input_shape[2:]
        img_arr = mx.image.imdecode(img)  # HWC, 3-channel by default
        img_arr = mx.image.imresize(img_arr, w, h)
        # Scale to [0, 1], then normalize per channel. These mean/std values are
        # the ones commonly used for ImageNet-pretrained models; presumably they
        # match how the served network was trained.
        img_arr = img_arr.astype(np.float32)
        img_arr /= 255
        img_arr = mx.image.color_normalize(img_arr,
                                           mean=mx.nd.array([0.485, 0.456, 0.406]),
                                           std=mx.nd.array([0.229, 0.224, 0.225]))
        img_arr = mx.nd.transpose(img_arr, (2, 0, 1))  # HWC -> CHW
        img_arr = img_arr.expand_dims(axis=0)  # add the batch axis -> NCHW
        return img_arr

    def inference(self, data):
        """
        Internal inference method for the MMS service: run the forward
        computation and return class probabilities.

        :param data: a single preprocessed NDArray batch (as returned by
            ``preprocess``).
        :return: NDArray holding the softmax (over the last axis) of the
            network output.
        """
        model_input = data.as_in_context(self.ctx)
        output = self.net(model_input)
        return output.softmax()

    def postprocess(self, data):
        """
        Map each sample's probabilities to its top-5 labels.

        :param data: iterable of per-sample probability vectors, e.g. the
            NDArray returned by ``inference``.
        :return: a single-element list wrapping one
            ``ndarray.top_probability(..., top=5)`` result per sample.
        :raises RuntimeError: if no class labels have been loaded.
        """
        # __init__ always defines self.labels (as None), so hasattr() could
        # never detect a missing synset; check the value instead, and raise
        # rather than assert so the check survives ``python -O``.
        if getattr(self, 'labels', None) is None:
            raise RuntimeError(
                "Can't find labels attribute. Did you put synset.txt file into "
                "model archive or manually load class label file in __init__?")
        return [[ndarray.top_probability(d, self.labels, top=5) for d in data]]

    def predict(self, data):
        """
        Run the full pipeline: preprocess -> inference -> postprocess.

        :param data: request data in the format accepted by ``preprocess``.
        :return: the result of ``postprocess``.
        """
        model_input = self.preprocess(data)
        probabilities = self.inference(model_input)
        return self.postprocess(probabilities)