@@ -1,6 +1,9 @@
-import tensorflow as tf
 from absl import app, flags, logging
 from absl.flags import FLAGS
+import tensorflow as tf
+physical_devices = tf.config.experimental.list_physical_devices('GPU')
+if len(physical_devices) > 0:
+    tf.config.experimental.set_memory_growth(physical_devices[0], True)
 import numpy as np
 import cv2
 from tensorflow.python.compiler.tensorrt import trt_convert as trt
@@ -14,47 +17,53 @@
 flags.DEFINE_string('output', './checkpoints/yolov4-trt-fp16-416', 'path to output')
 flags.DEFINE_integer('input_size', 416, 'path to output')
 flags.DEFINE_string('quantize_mode', 'float16', 'quantize mode (int8, float16)')
-flags.DEFINE_string('dataset', "./coco_dataset/coco/5k.txt", 'path to dataset')
-flags.DEFINE_integer('loop', 10, 'loop')
+flags.DEFINE_string('dataset', "/media/user/Source/Data/coco_dataset/coco/5k.txt", 'path to dataset')
+flags.DEFINE_integer('loop', 8, 'loop')
 
 def representative_data_gen():
     fimage = open(FLAGS.dataset).read().split()
+    batched_input = np.zeros((FLAGS.loop, FLAGS.input_size, FLAGS.input_size, 3), dtype=np.float32)
     for input_value in range(FLAGS.loop):
         if os.path.exists(fimage[input_value]):
             original_image=cv2.imread(fimage[input_value])
             original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
-            image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
+            image_data = utils.image_preporcess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
             img_in = image_data[np.newaxis, ...].astype(np.float32)
-            batched_input = tf.constant(img_in)
+            batched_input[input_value, :] = img_in
+            # batched_input = tf.constant(img_in)
             print(input_value)
-            yield (batched_input, )
+            # yield (batched_input, )
+            # yield tf.random.normal((1, 416, 416, 3)),
         else:
             continue
+    batched_input = tf.constant(batched_input)
+    yield (batched_input,)
 
 def save_trt():
+
     if FLAGS.quantize_mode == 'int8':
         conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             precision_mode=trt.TrtPrecisionMode.INT8,
-            max_workspace_size_bytes=8000000000,
+            max_workspace_size_bytes=4000000000,
             use_calibration=True,
-            max_batch_size=32)
+            max_batch_size=8)
         converter = trt.TrtGraphConverterV2(
             input_saved_model_dir=FLAGS.weights,
             conversion_params=conversion_params)
         converter.convert(calibration_input_fn=representative_data_gen)
     elif FLAGS.quantize_mode == 'float16':
         conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             precision_mode=trt.TrtPrecisionMode.FP16,
-            max_workspace_size_bytes=8000000000,
-            max_batch_size=32)
+            max_workspace_size_bytes=4000000000,
+            max_batch_size=8)
         converter = trt.TrtGraphConverterV2(
             input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
         converter.convert()
     else :
         conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             precision_mode=trt.TrtPrecisionMode.FP32,
-            max_workspace_size_bytes=8000000000,
-            max_batch_size=32)
+            max_workspace_size_bytes=4000000000,
+            max_batch_size=8)
         converter = trt.TrtGraphConverterV2(
             input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
         converter.convert()
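For context: the commit changes the INT8 calibration generator to fill a single NumPy batch of FLAGS.loop preprocessed images and yield it once as one tensor, and it lowers the TensorRT workspace size and batch limit. Below is a minimal sketch of how a converter like the one in save_trt() is typically driven end to end with the same TF-TRT API: convert, optionally pre-build engines, save, then reload and run a dummy batch. The output directory, the zero-filled dummy input, and the 'serving_default' signature key are illustrative assumptions, not part of this commit.

# Illustrative sketch only; paths and the signature key are assumptions.
import numpy as np
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

def convert_save_and_check(input_dir, output_dir, input_size=416):
    params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt.TrtPrecisionMode.FP16,
        max_workspace_size_bytes=4000000000)
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_dir, conversion_params=params)
    converter.convert()

    # Optionally pre-build engines for a fixed input shape before saving.
    def input_fn():
        yield (tf.constant(np.zeros((1, input_size, input_size, 3), np.float32)),)
    converter.build(input_fn=input_fn)
    converter.save(output_saved_model_dir=output_dir)

    # Reload and run one dummy batch to confirm the engine executes.
    loaded = tf.saved_model.load(output_dir)
    infer = loaded.signatures['serving_default']
    input_name = list(infer.structured_input_signature[1].keys())[0]
    dummy = tf.constant(np.zeros((1, input_size, input_size, 3), np.float32))
    outputs = infer(**{input_name: dummy})
    print({name: tensor.shape for name, tensor in outputs.items()})

For INT8, converter.convert() would instead be given calibration_input_fn=representative_data_gen, as in the script above, so the generator's single batched yield drives calibration.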