Skip to content

Commit 5f4938f

Browse files
[Update] FaceFusion Speed up (#446)
* update CMakeLists.txt * face_recognizer_postprocess cuda code and wrapper implement * trt code update * face_swap_postprocess cuda code implement * update code * facefusion pipeline test code update * facefusion pipeline test code update * fix multi thread face68landmarks code * multi thread yolofacev8 code * update name * bgr2rgb cuda code implement * use cuda rgb2bgr method * update code * speed up paste_back func * update to cuda version paste_back * time test for preprocess postprocess and inference * update code * update code --------- Co-authored-by: DefTruth <[email protected]>
1 parent a7cd9db commit 5f4938f

32 files changed

+1002
-82
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ enable_language(CUDA)
2828
set(LITE_AI_ROOT_DIR ${CMAKE_SOURCE_DIR})
2929

3030
option(ENABLE_TEST "build test examples." OFF)
31-
option(ENABLE_DEBUG_STRING "enable DEBUG string or not" OFF)
31+
option(ENABLE_DEBUG_STRING "enable DEBUG string or not" ON)
3232
option(ENABLE_ONNXRUNTIME "enable ONNXRuntime engine" ON)
3333
option(ENABLE_TENSORRT "enable TensorRT engine" OFF)
3434
option(ENABLE_MNN "enable MNN engine" OFF)

examples/lite/cv/test_lite_facefusion_pipeline.cpp

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
static void test_default()
66
{
77
#ifdef ENABLE_ONNXRUNTIME
8-
std::string face_swap_onnx_path = "../../../examples/hub/onnx/cv/inswapper_128.onnx";
9-
std::string face_detect_onnx_path = "../../../examples/hub/onnx/cv/yoloface_8n.onnx";
10-
std::string face_landmarks_68 = "../../../examples/hub/onnx/cv/2dfan4.onnx";
11-
std::string face_recognizer_onnx_path = "../../../examples/hub/onnx/cv/arcface_w600k_r50.onnx";
12-
std::string face_restoration_onnx_path = "../../../examples/hub/onnx/cv/gfpgan_1.4.onnx";
8+
std::string face_swap_onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/inswapper_128.onnx";
9+
std::string face_detect_onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/yoloface_8n.onnx";
10+
std::string face_landmarks_68 = "/home/lite.ai.toolkit/examples/hub/onnx/cv/2dfan4.onnx";
11+
std::string face_recognizer_onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/arcface_w600k_r50.onnx";
12+
std::string face_restoration_onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/gfpgan_1.4.onnx";
1313

1414
auto pipeLine = lite::cv::face::swap::facefusion::PipeLine(
1515
face_detect_onnx_path,
@@ -19,27 +19,62 @@ static void test_default()
1919
face_restoration_onnx_path
2020
);
2121

22-
std::string source_image_path = "../../../examples/lite/resources/test_lite_facefusion_pipeline_source.jpg";
23-
std::string target_image_path = "../../../examples/lite/resources/test_lite_facefusion_pipeline_target.jpg";
24-
std::string save_image_path = "../../../examples/logs/test_lite_facefusion_pipeline_result.jpg";
22+
std::string source_image_path = "/home/lite.ai.toolkit/1.jpg";
23+
std::string target_image_path = "/home/lite.ai.toolkit/2.jpg";
24+
std::string save_image_path = "/home/lite.ai.toolkit/result111111.jpg";
2525

2626

2727
// 写一个测试时间的代码
2828
auto start = std::chrono::high_resolution_clock::now();
2929

30+
pipeLine.detect(source_image_path,target_image_path,save_image_path);
31+
auto end = std::chrono::high_resolution_clock::now();
32+
std::chrono::duration<double> diff = end-start;
33+
std::cout << "Time: " << diff.count() << " s\n";
3034

3135

36+
#endif
37+
}
38+
39+
40+
41+
42+
static void test_tensorrt()
43+
{
44+
#ifdef ENABLE_TENSORRT
45+
std::string face_swap_onnx_path = "../../../examples/hub/trt/inswapper_128_fp16.engine";
46+
std::string face_detect_onnx_path = "../../../examples/hub/trt/yoloface_8n_fp16.engine";
47+
std::string face_landmarks_68 = "../../../examples/hub/trt/2dfan4_fp16.engine";
48+
std::string face_recognizer_onnx_path = "../../../examples/hub/trt/arcface_w600k_r50_fp16.engine";
49+
std::string face_restoration_onnx_path = "../../../examples/hub/trt/gfpgan_1.4_fp32.engine";
50+
51+
auto pipeLine = lite::trt::cv::face::swap::FaceFusionPipeLine (
52+
face_detect_onnx_path,
53+
face_landmarks_68,
54+
face_recognizer_onnx_path,
55+
face_swap_onnx_path,
56+
face_restoration_onnx_path
57+
);
58+
59+
std::string source_image_path = "../../../examples/logs/1.jpg";
60+
std::string target_image_path = "../../../examples/logs/5.jpg";
61+
std::string save_image_path = "../../../examples/logs/trt_pipeline_result_cuda_test_13_mt.jpg";
62+
63+
64+
// 写一个测试时间的代码
65+
auto start = std::chrono::high_resolution_clock::now();
66+
3267
pipeLine.detect(source_image_path,target_image_path,save_image_path);
3368
auto end = std::chrono::high_resolution_clock::now();
3469
std::chrono::duration<double> diff = end-start;
35-
std::cout << "Time: " << diff.count() << " s\n";
70+
std::cout << "Time: " << diff.count() * 1000<< " ms\n";
3671

3772

3873
#endif
3974
}
4075

4176
int main()
4277
{
43-
44-
test_default();
78+
test_tensorrt();
79+
// test_default();
4580
}

lite/ort/cv/yolofacev8.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Created by ai-test1 on 24-7-8.
2+
// Created by wangzijian on 24-7-8.
33
//
44

55
#include "yolofacev8.h"
@@ -9,6 +9,7 @@
99
using ortcv::YoloFaceV8;
1010

1111
float YoloFaceV8::get_iou(const lite::types::Boxf box1, const lite::types::Boxf box2) {
12+
// 左上角是坐标轴原点,右下角是坐标轴最大值
1213
float x1 = std::max(box1.x1, box2.x1);
1314
float y1 = std::max(box1.y1, box2.y1);
1415
float x2 = std::min(box1.x2, box2.x2);

lite/trt/cv/trt_face_68landmarks_mt.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ trt_face_68landmarks_mt::trt_face_68landmarks_mt(std::string &model_path, size_t
9999
worker_threads.emplace_back(&trt_face_68landmarks_mt::worker_function, this, i);
100100
}
101101

102+
affine_matrixs.resize(num_threads);
103+
img_with_landmarks_vec.resize(num_threads);
104+
102105
}
103106

104107
// 在cpp文件中修改相关实现
@@ -138,7 +141,7 @@ void trt_face_68landmarks_mt::worker_function(int thread_id) {
138141
}
139142

140143
void
141-
trt_face_68landmarks_mt::preprocess(const lite::types::Boxf &bounding_box, const cv::Mat &input_mat, cv::Mat &crop_img) {
144+
trt_face_68landmarks_mt::preprocess(const lite::types::Boxf &bounding_box, const cv::Mat &input_mat, cv::Mat &crop_img,int thread_id) {
142145
float xmin = bounding_box.x1;
143146
float ymin = bounding_box.y1;
144147
float xmax = bounding_box.x2;
@@ -159,7 +162,7 @@ trt_face_68landmarks_mt::preprocess(const lite::types::Boxf &bounding_box, const
159162

160163
cv::Size crop_size(256, 256);
161164

162-
std::tie(crop_img, affine_matrix) = face_utils::warp_face_by_translation(input_mat, translation, scale, crop_size);
165+
std::tie(crop_img, affine_matrixs[thread_id]) = face_utils::warp_face_by_translation(input_mat, translation, scale, crop_size);
163166

164167
crop_img.convertTo(crop_img,CV_32FC3,1 / 255.f);
165168
}
@@ -168,10 +171,10 @@ trt_face_68landmarks_mt::preprocess(const lite::types::Boxf &bounding_box, const
168171
void trt_face_68landmarks_mt::process_single_task( InferenceTask &task, int thread_id) {
169172
if (task.input_mat.empty()) return;
170173

171-
img_with_landmarks = task.input_mat.clone();
174+
img_with_landmarks_vec[thread_id] = task.input_mat.clone();
172175
cv::Mat crop_image;
173176

174-
preprocess(task.bbox, task.input_mat, crop_image);
177+
preprocess(task.bbox, task.input_mat, crop_image, thread_id);
175178

176179
std::vector<float> input_data;
177180

@@ -198,13 +201,13 @@ void trt_face_68landmarks_mt::process_single_task( InferenceTask &task, int thre
198201

199202
// 带出结果
200203
// 指针指向带出来
201-
*task.face_landmark_5of68 = postprocess(output.data());
204+
*task.face_landmark_5of68 = postprocess(output.data(),thread_id);
202205

203206
task.completion_promise.set_value();
204207
}
205208

206209

207-
std::vector<cv::Point2f> trt_face_68landmarks_mt::postprocess(float *trt_outputs) {
210+
std::vector<cv::Point2f> trt_face_68landmarks_mt::postprocess(float *trt_outputs,int thread_id) {
208211
std::vector<cv::Point2f> landmarks;
209212

210213
for (int i = 0;i < 68; ++i)
@@ -215,15 +218,15 @@ std::vector<cv::Point2f> trt_face_68landmarks_mt::postprocess(float *trt_outputs
215218
}
216219

217220
cv::Mat inverse_affine_matrix;
218-
cv::invertAffineTransform(affine_matrix, inverse_affine_matrix);
221+
cv::invertAffineTransform(affine_matrixs[thread_id], inverse_affine_matrix);
219222

220223
cv::transform(landmarks, landmarks, inverse_affine_matrix);
221224

222225
return face_utils::convert_face_landmark_68_to_5(landmarks);
223226
}
224227

225228

226-
void trt_face_68landmarks_mt::postprocess(float *trt_outputs, std::vector<cv::Point2f> &face_landmark_5of68) {
229+
void trt_face_68landmarks_mt::postprocess(float *trt_outputs, std::vector<cv::Point2f> &face_landmark_5of68,int thread_id) {
227230
std::vector<cv::Point2f> landmarks;
228231

229232
for (int i = 0;i < 68; ++i)
@@ -234,7 +237,7 @@ void trt_face_68landmarks_mt::postprocess(float *trt_outputs, std::vector<cv::Po
234237
}
235238

236239
cv::Mat inverse_affine_matrix;
237-
cv::invertAffineTransform(affine_matrix, inverse_affine_matrix);
240+
cv::invertAffineTransform(affine_matrixs[thread_id], inverse_affine_matrix);
238241

239242
cv::transform(landmarks, landmarks, inverse_affine_matrix);
240243

lite/trt/cv/trt_face_68landmarks_mt.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,18 @@ class trt_face_68landmarks_mt {
6565
// 实际的推理函数
6666
void process_single_task(InferenceTask& task, int thread_id);
6767

68-
void preprocess(const lite::types::Boxf &bouding_box,const cv::Mat &input_mat,cv::Mat &crop_img);
68+
void preprocess(const lite::types::Boxf &bouding_box,const cv::Mat &input_mat,cv::Mat &crop_img,int thread_id);
6969

70-
void postprocess(float *trt_outputs, std::vector<cv::Point2f> &face_landmark_5of68);
70+
void postprocess(float *trt_outputs, std::vector<cv::Point2f> &face_landmark_5of68,int thread_id);
7171

72-
std::vector<cv::Point2f> postprocess(float *trt_outputs);
72+
std::vector<cv::Point2f> postprocess(float *trt_outputs,int thread_id);
7373

7474

7575

7676
private:
77-
cv::Mat affine_matrix;
78-
cv::Mat img_with_landmarks;
77+
std::vector<cv::Mat> affine_matrixs;
78+
std::vector<cv::Mat> img_with_landmarks_vec;
79+
7980

8081
public:
8182
explicit trt_face_68landmarks_mt(std::string& model_path, size_t num_threads = 4);

lite/trt/cv/trt_face_recognizer.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,21 @@ void TRTFaceFusionFaceRecognizer::detect(cv::Mat &input_mat, std::vector<cv::Poi
5252
std::vector<float> normal_embeding(output.begin(),output.end());
5353

5454

55-
float norm = 0.0f;
56-
for (const auto &val : normal_embeding) {
57-
norm += val * val;
58-
}
59-
norm = std::sqrt(norm);
55+
launch_face_recognizer_postprocess(
56+
static_cast<float*>(buffers[1]),
57+
output_node_dims[0][0] * output_node_dims[0][1],
58+
output.data()
59+
);
60+
// float norm = 0.0f;
61+
// for (const auto &val : normal_embeding) {
62+
// norm += val * val;
63+
// }
64+
// norm = std::sqrt(norm);
65+
//
66+
// for (auto &val : normal_embeding) {
67+
// val /= norm;
68+
// }
6069

61-
for (auto &val : normal_embeding) {
62-
val /= norm;
63-
}
6470

6571
std::cout<<"done!"<<std::endl;
6672

lite/trt/cv/trt_face_recognizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "lite/trt/core/trt_core.h"
99
#include "lite/trt/core/trt_utils.h"
1010
#include "lite/trt/core/trt_types.h"
11+
#include "lite/trt/kernel/face_recognizer_postprocess_manager.h"
1112

1213
namespace trtcv{
1314
class LITE_EXPORTS TRTFaceFusionFaceRecognizer : BasicTRTHandler{

lite/trt/cv/trt_face_restoration.cpp

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,53 @@ void TRTFaceFusionFaceRestoration::detect(cv::Mat &face_swap_image, std::vector<
1111

1212
cv::Mat crop_image;
1313
cv::Mat affine_matrix;
14-
std::tie(crop_image,affine_matrix) = face_utils::warp_face_by_face_landmark_5(face_swap_image,target_landmarks_5,face_utils::FFHQ_512);
14+
// 记录时间
15+
auto start_warp = std::chrono::high_resolution_clock::now();
16+
std::tie(crop_image,affine_matrix) = face_utils::warp_face_by_face_landmark_5(face_swap_image,target_landmarks_5,
17+
face_utils::FFHQ_512);
1518

1619
std::vector<float> crop_size = {512,512};
1720
cv::Mat box_mask = face_utils::create_static_box_mask(crop_size);
1821
std::vector<cv::Mat> crop_mask_list;
1922
crop_mask_list.emplace_back(box_mask);
2023

21-
cv::cvtColor(crop_image,crop_image,cv::COLOR_BGR2RGB);
22-
crop_image.convertTo(crop_image,CV_32FC3,1.f / 255.f);
23-
crop_image.convertTo(crop_image,CV_32FC3,2.0f,-1.f);
24+
cv::Mat crop_image_rgb;
25+
launch_bgr2rgb(crop_image,crop_image_rgb);
26+
crop_image_rgb.convertTo(crop_image_rgb,CV_32FC3,1.f / 255.f);
27+
crop_image_rgb.convertTo(crop_image_rgb,CV_32FC3,2.0f,-1.f);
2428

2529
std::vector<float> input_vector;
26-
trtcv::utils::transform::create_tensor(crop_image,input_vector,input_node_dims,trtcv::utils::transform::CHW);
30+
trtcv::utils::transform::create_tensor(crop_image_rgb,input_vector,input_node_dims,trtcv::utils::transform::CHW);
2731

28-
// 拷贝
32+
auto end_warp = std::chrono::high_resolution_clock::now();
33+
std::chrono::duration<double, std::milli> fp_ms_warp = end_warp - start_warp;
34+
std::cout << "FaceRestoration preprocess time: " << fp_ms_warp.count() << "ms" << std::endl;
2935

36+
37+
// 记录时间
38+
auto start = std::chrono::high_resolution_clock::now();
3039
// 先不用拷贝了 处理完成再拷贝出来 类似于整个后处理放在GPU上完成
3140
cudaMemcpyAsync(buffers[0],input_vector.data(),1 * 3 * 512 * 512 * sizeof(float),cudaMemcpyHostToDevice,stream);
32-
3341
// 同步
3442
cudaStreamSynchronize(stream);
35-
3643
// 推理
3744
bool status = trt_context->enqueueV3(stream);
45+
3846
if (!status) {
3947
std::cerr << "Failed to inference" << std::endl;
4048
return;
4149
}
42-
43-
4450
// 同步
4551
cudaStreamSynchronize(stream);
52+
auto end = std::chrono::high_resolution_clock::now();
53+
std::chrono::duration<double, std::milli> fp_ms = end - start;
54+
std::cout << "FaceRestoration Inference time: " << fp_ms.count() << "ms" << std::endl;
4655
std::vector<unsigned char> transposed_data(1 * 3 * 512 * 512);
4756

4857
// std::vector<float> transposed_data(1 * 3 * 512 * 512);
4958

59+
// 记录时间
60+
auto start_postprocess = std::chrono::high_resolution_clock::now();
5061
// 这里buffer1就是输出了
5162
launch_face_restoration_postprocess(
5263
static_cast<float*>(buffers[1]),
@@ -64,47 +75,31 @@ void TRTFaceFusionFaceRestoration::detect(cv::Mat &face_swap_image, std::vector<
6475
std::vector<float> output_vector(1 * 3 * 512 * 512);
6576
// cudaMemcpyAsync(output_vector.data(),buffers[1],1 * 3 * 512 * 512 * sizeof(float),cudaMemcpyDeviceToHost,stream);
6677
cudaStreamSynchronize(stream);
67-
//
6878
// 后处理
6979
int channel = 3;
7080
int height = 512;
7181
int width = 512;
72-
// std::vector<float> output(channel * height * width);
73-
// output.assign(output_vector.begin(),output_vector.end());
74-
//
75-
// std::transform(output.begin(),output.end(),output.begin(),
76-
// [](double x){return std::max(-1.0,std::max(-1.0,std::min(1.0,x)));});
77-
//
78-
// std::transform(output.begin(),output.end(),output.begin(),
79-
// [](double x){return (x + 1.f) /2.f;});
80-
//
81-
// // CHW2HWC
82-
// for (int c = 0; c < channel; ++c){
83-
// for (int h = 0 ; h < height; ++h){
84-
// for (int w = 0; w < width ; ++w){
85-
// int src_index = c * (height * width) + h * width + w;
86-
// int dst_index = h * (width * channel) + w * channel + c;
87-
// transposed_data[dst_index] = output[src_index];
88-
// }
89-
// }
90-
// }
91-
//
92-
// std::transform(transposed_data.begin(),transposed_data.end(),transposed_data.begin(),
93-
// [](float x){return std::round(x * 255.f);});
94-
//
95-
// std::transform(transposed_data.begin(), transposed_data.end(), transposed_data.begin(),
96-
// [](float x) { return static_cast<uint8_t>(x); });
9782

9883

9984
cv::Mat mat(height, width, CV_32FC3, transposed_data_float.data());
100-
// cv::imwrite("/home/lite.ai.toolkit/mid_process.jpg",mat);
10185
cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR);
86+
// 到这里为止基本不耗时
10287

10388

10489
auto crop_mask = crop_mask_list[0];
105-
cv::Mat paste_frame = face_utils::paste_back(ori_image,mat,crop_mask,affine_matrix);
106-
90+
// 这里的paste_back 40ms左右
91+
cv::Mat paste_frame = launch_paste_back(ori_image,mat,crop_mask,affine_matrix);
92+
// cv::Mat paste_frame = face_utils::paste_back(ori_image,mat,crop_mask,affine_matrix);
10793
cv::Mat dst_image = face_utils::blend_frame(ori_image,paste_frame);
94+
auto end_postprocess = std::chrono::high_resolution_clock::now();
95+
std::chrono::duration<double, std::milli> fp_ms_postprocess = end_postprocess - start_postprocess;
96+
std::cout << "FaceRestoration postprocess time: " << fp_ms_postprocess.count() << "ms" << std::endl;
10897

98+
// 记录时间
99+
auto start_save = std::chrono::high_resolution_clock::now();
109100
cv::imwrite(face_enchaner_path,dst_image);
101+
auto end_save = std::chrono::high_resolution_clock::now();
102+
std::chrono::duration<double, std::milli> fp_ms_save = end_save - start_save;
103+
std::cout << "FaceRestoration save time: " << fp_ms_save.count() << "ms" << std::endl;
104+
110105
}

lite/trt/cv/trt_face_restoration.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
#include "lite/trt/core/trt_config.h"
1010
#include "lite/ort/cv/face_utils.h"
1111
#include "lite/trt/kernel/face_restoration_postprocess_manager.h"
12+
#include "lite/trt/kernel/bgr2rgb_manager.h"
13+
#include "lite/trt/kernel/paste_back_manager.h"
1214
namespace trtcv{
1315
class LITE_EXPORTS TRTFaceFusionFaceRestoration : BasicTRTHandler{
1416
public:
1517
explicit TRTFaceFusionFaceRestoration(const std::string& _trt_model_path,unsigned int _num_threads = 1) :
1618
BasicTRTHandler(_trt_model_path,_num_threads){};;
1719
public:
20+
// 这个是直接保存的
1821
void detect(cv::Mat &face_swap_image,std::vector<cv::Point2f > &target_landmarks_5 ,const std::string &face_enchaner_path);
1922

2023
};

0 commit comments

Comments
 (0)