Skip to content

Commit d3aed59

Browse files
committed
update convert tftrt
1 parent f27caf4 commit d3aed59

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

convert_trt.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import tensorflow as tf
21
from absl import app, flags, logging
32
from absl.flags import FLAGS
3+
import tensorflow as tf
4+
physical_devices = tf.config.experimental.list_physical_devices('GPU')
5+
if len(physical_devices) > 0:
6+
tf.config.experimental.set_memory_growth(physical_devices[0], True)
47
import numpy as np
58
import cv2
69
from tensorflow.python.compiler.tensorrt import trt_convert as trt
@@ -14,47 +17,53 @@
1417
flags.DEFINE_string('output', './checkpoints/yolov4-trt-fp16-416', 'path to output')
1518
flags.DEFINE_integer('input_size', 416, 'path to output')
1619
flags.DEFINE_string('quantize_mode', 'float16', 'quantize mode (int8, float16)')
17-
flags.DEFINE_string('dataset', "./coco_dataset/coco/5k.txt", 'path to dataset')
18-
flags.DEFINE_integer('loop', 10, 'loop')
20+
flags.DEFINE_string('dataset', "/media/user/Source/Data/coco_dataset/coco/5k.txt", 'path to dataset')
21+
flags.DEFINE_integer('loop', 8, 'loop')
1922

2023
def representative_data_gen():
2124
fimage = open(FLAGS.dataset).read().split()
25+
batched_input = np.zeros((FLAGS.loop, FLAGS.input_size, FLAGS.input_size, 3), dtype=np.float32)
2226
for input_value in range(FLAGS.loop):
2327
if os.path.exists(fimage[input_value]):
2428
original_image=cv2.imread(fimage[input_value])
2529
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
26-
image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
30+
image_data = utils.image_preporcess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
2731
img_in = image_data[np.newaxis, ...].astype(np.float32)
28-
batched_input = tf.constant(img_in)
32+
batched_input[input_value, :] = img_in
33+
# batched_input = tf.constant(img_in)
2934
print(input_value)
30-
yield (batched_input, )
35+
# yield (batched_input, )
36+
# yield tf.random.normal((1, 416, 416, 3)),
3137
else:
3238
continue
39+
batched_input = tf.constant(batched_input)
40+
yield (batched_input,)
3341

3442
def save_trt():
43+
3544
if FLAGS.quantize_mode == 'int8':
3645
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
3746
precision_mode=trt.TrtPrecisionMode.INT8,
38-
max_workspace_size_bytes=8000000000,
47+
max_workspace_size_bytes=4000000000,
3948
use_calibration=True,
40-
max_batch_size=32)
49+
max_batch_size=8)
4150
converter = trt.TrtGraphConverterV2(
4251
input_saved_model_dir=FLAGS.weights,
4352
conversion_params=conversion_params)
4453
converter.convert(calibration_input_fn=representative_data_gen)
4554
elif FLAGS.quantize_mode == 'float16':
4655
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
4756
precision_mode=trt.TrtPrecisionMode.FP16,
48-
max_workspace_size_bytes=8000000000,
49-
max_batch_size=32)
57+
max_workspace_size_bytes=4000000000,
58+
max_batch_size=8)
5059
converter = trt.TrtGraphConverterV2(
5160
input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
5261
converter.convert()
5362
else :
5463
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
5564
precision_mode=trt.TrtPrecisionMode.FP32,
56-
max_workspace_size_bytes=8000000000,
57-
max_batch_size=32)
65+
max_workspace_size_bytes=4000000000,
66+
max_batch_size=8)
5867
converter = trt.TrtGraphConverterV2(
5968
input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params)
6069
converter.convert()

0 commit comments

Comments
 (0)