-
Notifications
You must be signed in to change notification settings - Fork 351
/
Copy pathpaddle_trt_infer.py
450 lines (409 loc) · 16 KB
/
paddle_trt_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import argparse
from tqdm import tqdm
import pkg_resources as pkg
import time
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from dataset import COCOValDataset
from post_process import YOLOPostProcess, coco_metric
def argsparser():
    """Build the command-line argument parser for this script.

    Boolean options (``--benchmark`` and ``--use_dynamic_shape``) are parsed
    from their textual form, so ``--benchmark False`` really disables the
    option. With a plain ``type=bool`` every non-empty string, including
    ``"False"``, is converted to ``True``.

    Returns:
        argparse.ArgumentParser: parser with all options registered; the
            caller is expected to call ``parse_args()`` on it.
    """

    def str2bool(value):
        """Convert a command-line token such as 'true' or '0' to a bool."""
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ('true', 't', 'yes', 'y', 'on', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', 'off', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value (true/false), got {!r}".format(value))

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--model_path', type=str, help="inference model filepath")
    parser.add_argument(
        '--image_file',
        type=str,
        default=None,
        help="image path, if set image_file, it will not eval coco.")
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default='dataset/coco',
        help="COCO dataset dir.")
    parser.add_argument(
        '--val_image_dir',
        type=str,
        default='val2017',
        help="COCO dataset val image dir.")
    parser.add_argument(
        '--val_anno_path',
        type=str,
        default='annotations/instances_val2017.json',
        help="COCO dataset anno path.")
    parser.add_argument(
        '--benchmark',
        type=str2bool,
        default=False,
        help="Whether run benchmark or not.")
    parser.add_argument(
        '--use_dynamic_shape',
        type=str2bool,
        default=True,
        help="Whether use dynamic shape or not.")
    parser.add_argument(
        '--run_mode',
        type=str,
        default='paddle',
        help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU"
    )
    parser.add_argument(
        '--arch', type=str, default='YOLOv5', help="architectures name.")
    parser.add_argument('--img_shape', type=int, default=640, help="input_size")
    parser.add_argument(
        '--batch_size', type=int, default=1, help="Batch size of model input.")
    return parser
# Display names for the 80 object classes, indexed by integer class id.
# infer() passes this list to draw_box() as `class_names`, where it is used
# both to size the colour palette and to label each drawn detection.
CLASS_LABEL = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
def preprocess(image, input_size, mean=None, std=None, swap=(2, 0, 1)):
    """Resize an image into a padded canvas and turn it into a model input.

    The image is scaled by one ratio so that it fits inside ``input_size``
    and is written into the top-left corner of a canvas filled with 114.
    The canvas then has its last axis reversed, is divided by 255,
    optionally normalised with ``mean`` / ``std`` and transposed by ``swap``.

    Args:
        image: input image array; only its ``shape`` and pixel data are used.
        input_size: target size, indexed as ``(height, width)``.
        mean: optional value subtracted after the division by 255.
        std: optional divisor applied after the mean subtraction.
        swap: axis order passed to ``transpose``.

    Returns:
        tuple: ``(array, ratio)`` where ``array`` is a contiguous float32
            array and ``ratio`` is the resize factor that was applied.
    """
    target_h, target_w = input_size[0], input_size[1]
    if len(image.shape) == 3:
        canvas = np.ones((target_h, target_w, 3)) * 114.0
    else:
        canvas = np.ones(input_size) * 114.0

    src = np.array(image)
    ratio = min(target_h / src.shape[0], target_w / src.shape[1])
    new_h = int(src.shape[0] * ratio)
    new_w = int(src.shape[1] * ratio)

    # cv2.resize takes the destination size as (width, height).
    scaled = cv2.resize(src, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas[:new_h, :new_w] = scaled.astype(np.float32)

    # Reverse the last axis (presumably BGR -> RGB for images read by cv2).
    canvas = canvas[:, :, ::-1]
    canvas /= 255.0
    if mean is not None:
        canvas -= mean
    if std is not None:
        canvas /= std

    return np.ascontiguousarray(canvas.transpose(swap), dtype=np.float32), ratio
def get_color_map_list(num_classes):
    """Generate one distinct 3-component colour per class index.

    The bits of each class index are consumed three at a time; bit 0, 1 and 2
    of every group are written into the first, second and third colour
    component respectively, starting from the most significant bit (bit 7)
    and moving downwards for each further group.

    Args:
        num_classes (int): number of colours to generate.

    Returns:
        list: ``num_classes`` lists of three ints; index 0 maps to
            ``[0, 0, 0]``.
    """
    palette = []
    for class_idx in range(num_classes):
        channels = [0, 0, 0]
        remaining = class_idx
        shift = 7
        while remaining:
            for channel in range(3):
                channels[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.append(channels)
    return palette
def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    """Draw detection boxes and their labels onto ``img`` in place.

    Args:
        img: image to draw on; it is modified by the cv2 drawing calls.
        boxes: per-detection coordinates, read as ``x0, y0, x1, y1``.
        scores: per-detection confidence; shown as a percentage in the label.
        cls_ids: per-detection class id (converted with ``int``).
        conf (float): detections with ``score < conf`` are skipped.
        class_names: optional sequence of names indexed by class id. When
            omitted, the numeric class id is used as the label and the
            palette is sized from the largest class id.

    Returns:
        The same ``img`` object that was passed in.
    """
    # Nothing to draw: return early so an empty result never needs a palette
    # (and never trips over ``class_names`` being None).
    if len(boxes) == 0:
        return img

    class_ids = [int(cls_id) for cls_id in cls_ids]
    if class_names is None:
        num_colors = max(class_ids, default=-1) + 1
    else:
        num_colors = len(class_names)
    color_list = get_color_map_list(num_colors)
    font = cv2.FONT_HERSHEY_SIMPLEX

    for i, box in enumerate(boxes):
        score = scores[i]
        if score < conf:
            continue
        cls_id = class_ids[i]
        color = tuple(color_list[cls_id])
        label = str(cls_id) if class_names is None else class_names[cls_id]
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])
        text = '{}:{:.1f}%'.format(label, score * 100)
        # NOTE(review): the label background is measured at font scale 0.4 /
        # thickness 1, while the text below is drawn at scale 0.8 /
        # thickness 2. Left as is so the rendered output does not change.
        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
        cv2.rectangle(img, (x0, y0 + 1),
                      (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
                      color, -1)
        cv2.putText(
            img,
            text, (x0, y0 + txt_size[1]),
            font,
            0.8, (0, 255, 0),
            thickness=2)
    return img
def get_current_memory_mb():
    """Report the current CPU and GPU memory usage in megabytes.

    CPU memory is the ``uss`` value of this process as reported by
    ``psutil``; GPU memory is the ``used`` value NVML reports for the
    selected device. Any of ``pynvml``, ``psutil`` or ``GPUtil`` that is
    missing is installed with pip on the fly, so this call can be slow.

    Returns:
        tuple: ``(cpu_mem, gpu_mem)`` in MB, each rounded to 4 decimals.
            ``gpu_mem`` is 0 when GPUtil reports no GPU.
    """
    for package in ('pynvml', 'psutil', 'GPUtil'):
        try:
            pkg.require(package)
        except Exception:
            # Best effort: the requirement is missing or unsatisfied, so try
            # to install it. KeyboardInterrupt/SystemExit are not swallowed.
            from pip._internal import main as pip_main
            pip_main(['install', package])
    import pynvml
    import psutil
    import GPUtil

    # CUDA_VISIBLE_DEVICES may be unset, empty, a comma separated list
    # ("1,2") or hold non-numeric ids; use its first numeric entry and fall
    # back to device 0 instead of crashing in int().
    visible = str(os.environ.get('CUDA_VISIBLE_DEVICES', '0'))
    first_entry = visible.split(',')[0].strip()
    gpu_id = int(first_entry) if first_entry.isdigit() else 0

    process = psutil.Process(os.getpid())
    cpu_mem = process.memory_full_info().uss / 1024. / 1024.

    gpu_mem = 0
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        if gpu_id >= len(gpus):
            gpu_id = 0
        pynvml.nvmlInit()
        # NOTE(review): the handle used to be hard-coded to index 0 although
        # the device was otherwise selected through gpu_id; this assumes the
        # NVML index matches the CUDA_VISIBLE_DEVICES entry -- confirm for
        # setups with a non-default device ordering.
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpu_mem = meminfo.used / 1024. / 1024.
    return round(cpu_mem, 4), round(gpu_mem, 4)
def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False):
    """Build a Paddle Inference predictor for the model in ``model_dir``.

    Args:
        model_dir (str): directory containing ``model.pdmodel`` and
            ``model.pdiparams``; the dynamic-shape cache file
            ``dynamic_shape.txt`` is read from / written to it as well.
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16/trt_int8).
            Any value that is not one of the three ``trt_*`` modes runs
            without TensorRT.
        batch_size (int): used as the TensorRT max batch size and to scale
            the TensorRT workspace size.
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU.
        min_subgraph_size (int): forwarded to ``enable_tensorrt_engine``.
        use_dynamic_shape (bool): use TensorRT dynamic shape or not.
        trt_min_shape (int): min shape for dynamic shape in trt
            (currently not used by this function).
        trt_max_shape (int): max shape for dynamic shape in trt
            (currently not used by this function).
        trt_opt_shape (int): opt shape for dynamic shape in trt
            (currently not used by this function).
        trt_calib_mode (bool): If the model is produced by TRT offline
            quantitative calibration, trt_calib_mode need to set True.
        cpu_threads (int): number of CPU math library threads (CPU only).
        enable_mkldnn (bool): try to enable mkldnn when running on CPU.
        enable_mkldnn_bfloat16 (bool): additionally enable mkldnn bfloat16.

    Returns:
        tuple: ``(predictor, rerun_flag)``. ``rerun_flag`` is True when no
            dynamic-shape file existed yet and this run only collects shape
            range info, so the program must be run again for real results.

    Raises:
        ValueError: predict by TensorRT need device == 'GPU'.
    """
    rerun_flag = False
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))
    config = Config(
        os.path.join(model_dir, 'model.pdmodel'),
        os.path.join(model_dir, 'model.pdiparams'))
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception:
                # Deliberate best effort: keep running without mkldnn.
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map:
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:
            # Derive the cache path from this function's own argument rather
            # than the script-level FLAGS, so the function also works when
            # it is called without the CLI globals being set up.
            dynamic_shape_file = os.path.join(model_dir, 'dynamic_shape.txt')
            if os.path.exists(dynamic_shape_file):
                config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
                                                           True)
                print('trt set dynamic shape done!')
            else:
                config.collect_shape_range_info(dynamic_shape_file)
                print('Start collect dynamic shape...')
                rerun_flag = True

    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    predictor = create_predictor(config)
    return predictor, rerun_flag
def eval(predictor, val_loader, anno_file, rerun_flag=False):
    """Run the predictor over ``val_loader`` and report COCO metrics.

    For every batch the image is copied to the predictor input, the first
    output is post-processed with ``YOLOPostProcess`` and the resulting
    boxes, box counts and image ids are accumulated. After the loop the
    average memory usage is printed and ``coco_metric`` is called.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    in this file refer to it.

    Args:
        predictor: Paddle Inference predictor.
        val_loader: iterable of dict batches providing at least the keys
            ``'image'``, ``'scale_factor'`` and ``'im_id'``; must support
            ``len()``.
        anno_file: annotation file handed to ``coco_metric``.
        rerun_flag (bool): when True, return right after the first batch has
            been run (the shape-collection pass) without computing metrics.
    """
    bboxes_list, bbox_nums_list, image_id_list = [], [], []
    cpu_mems, gpu_mems = 0, 0
    sample_nums = len(val_loader)
    image_key = ('x2paddle_image_arrays'
                 if FLAGS.arch == 'YOLOv6' else 'x2paddle_images')
    # The post-processor is configured identically for every batch, so it is
    # built once instead of once per iteration.
    postprocess = YOLOPostProcess(
        score_threshold=0.001, nms_threshold=0.65, multi_label=True)
    with tqdm(
            total=sample_nums,
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for data in val_loader:
            data_all = {k: np.array(v) for k, v in data.items()}
            inputs = {image_key: data_all['image']}
            for name in predictor.get_input_names():
                input_tensor = predictor.get_input_handle(name)
                input_tensor.copy_from_cpu(inputs[name])
            predictor.run()
            output_names = predictor.get_output_names()
            boxes_tensor = predictor.get_output_handle(output_names[0])
            outs = boxes_tensor.copy_to_cpu()
            if rerun_flag:
                return
            res = postprocess(np.array(outs), data_all['scale_factor'])
            bboxes_list.append(res['bbox'])
            bbox_nums_list.append(res['bbox_num'])
            image_id_list.append(np.array(data_all['im_id']))
            cpu_mem, gpu_mem = get_current_memory_mb()
            cpu_mems += cpu_mem
            gpu_mems += gpu_mem
            t.update()
    # Guard against an empty loader so the summary does not divide by zero.
    divisor = max(sample_nums, 1)
    print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format(
        cpu_mems / divisor, gpu_mems / divisor))
    coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list)
def infer(predictor):
    """Run single-image inference on ``FLAGS.image_file`` and save the result.

    The image is preprocessed to ``FLAGS.img_shape`` x ``FLAGS.img_shape``,
    fed to the predictor, and the first output is post-processed. Timing and
    memory statistics are printed; with ``FLAGS.benchmark`` set, 50 warm-up
    runs and 100 timed runs are used instead of one each. Detections are
    drawn on the original image, which is written to ``output.jpg``.

    Args:
        predictor: Paddle Inference predictor.

    Raises:
        ValueError: if the image file cannot be read.
    """
    warmup, repeats = 1, 1
    if FLAGS.benchmark:
        warmup, repeats = 50, 100
    origin_img = cv2.imread(FLAGS.image_file)
    if origin_img is None:
        # cv2.imread signals a missing/undecodable file by returning None;
        # fail here with a clear message instead of an AttributeError later.
        raise ValueError(
            "Failed to read image file: {}".format(FLAGS.image_file))
    input_image, scale_factor = preprocess(origin_img,
                                           [FLAGS.img_shape, FLAGS.img_shape])
    input_image = np.expand_dims(input_image, axis=0)
    scale_factor = np.array([[scale_factor, scale_factor]])
    inputs = {}
    if FLAGS.arch == 'YOLOv6':
        inputs['x2paddle_image_arrays'] = input_image
    else:
        inputs['x2paddle_images'] = input_image
    for name in predictor.get_input_names():
        input_tensor = predictor.get_input_handle(name)
        input_tensor.copy_from_cpu(inputs[name])
    for _ in range(warmup):
        predictor.run()

    np_boxes = None
    predict_time = 0.
    time_min = float("inf")
    time_max = float('-inf')
    cpu_mems, gpu_mems = 0, 0
    for _ in range(repeats):
        # The timed section covers the run and the copy of the first output
        # back to the CPU; memory sampling stays outside of it.
        start_time = time.time()
        predictor.run()
        output_names = predictor.get_output_names()
        boxes_tensor = predictor.get_output_handle(output_names[0])
        np_boxes = boxes_tensor.copy_to_cpu()
        end_time = time.time()
        timed = end_time - start_time
        time_min = min(time_min, timed)
        time_max = max(time_max, timed)
        predict_time += timed
        cpu_mem, gpu_mem = get_current_memory_mb()
        cpu_mems += cpu_mem
        gpu_mems += gpu_mem
    print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format(cpu_mems / repeats,
                                                         gpu_mems / repeats))
    time_avg = predict_time / repeats
    print('Inference time(ms): min={}, max={}, avg={}'.format(
        round(time_min * 1000, 2),
        round(time_max * 1000, 1), round(time_avg * 1000, 1)))
    postprocess = YOLOPostProcess(
        score_threshold=0.001, nms_threshold=0.65, multi_label=True)
    res = postprocess(np_boxes, scale_factor)
    # Draw rectangles and labels on the original image
    dets = res['bbox']
    if dets is not None:
        # Column 0 is read as the class id, column 1 as the score and the
        # remaining columns as the box coordinates.
        final_boxes, final_scores, final_class = (dets[:, 2:], dets[:, 1],
                                                  dets[:, 0])
        res_img = draw_box(
            origin_img,
            final_boxes,
            final_scores,
            final_class,
            conf=0.5,
            class_names=CLASS_LABEL)
        cv2.imwrite('output.jpg', res_img)
        print('The prediction results are saved in output.jpg.')
def main():
    """Create the predictor and run either single-image inference or COCO eval.

    When ``FLAGS.image_file`` is set, ``infer`` is run on that image;
    otherwise the COCO validation set described by the dataset flags is
    evaluated. If the predictor was created in shape-collection mode, a
    reminder to rerun the program is printed at the end.
    """
    predictor, rerun_flag = load_predictor(
        FLAGS.model_path,
        run_mode=FLAGS.run_mode,
        # Forward the batch size so the TensorRT engine is configured for
        # the same batch size the DataLoader below produces.
        batch_size=FLAGS.batch_size,
        device=FLAGS.device,
        use_dynamic_shape=FLAGS.use_dynamic_shape)

    if FLAGS.image_file:
        infer(predictor)
    else:
        dataset = COCOValDataset(
            dataset_dir=FLAGS.dataset_dir,
            image_dir=FLAGS.val_image_dir,
            anno_path=FLAGS.val_anno_path)
        anno_file = dataset.ann_file
        val_loader = paddle.io.DataLoader(
            dataset, batch_size=FLAGS.batch_size, drop_last=True)
        eval(predictor, val_loader, anno_file, rerun_flag=rerun_flag)

    if rerun_flag:
        print(
            "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
        )
if __name__ == '__main__':
    # Script entry point: switch Paddle to static-graph mode, parse the
    # command line into the module-level FLAGS read by the functions above,
    # then start the run.
    paddle.enable_static()
    FLAGS = argsparser().parse_args()

    # The DataLoader has to run on the CPU device.
    paddle.set_device('cpu')

    main()