Skip to content

Commit 4252d27

Browse files
[TRT] support MODNet (#442)
* add modnet trt test code * modnet trt implement * update code * add trt modnet
1 parent 557521d commit 4252d27

File tree

6 files changed

+284
-5
lines changed

6 files changed

+284
-5
lines changed

examples/lite/cv/test_lite_modnet.cpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,55 @@ static void test_onnxruntime()
9494
#endif
9595
}
9696

97+
98+
99+
// Smoke test for the TensorRT MODNet backend: runs matting on the sample
// image, writes whichever result images were produced, and composites the
// subject over the sample background. Compiles to a no-op unless
// ENABLE_TENSORRT is defined.
static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
  std::string engine_path = "../../../examples/hub/trt/modnet_fp16.engine";
  std::string test_img_path = "../../../examples/lite/resources/test_lite_matting_input.jpg";
  std::string test_bgr_path = "../../../examples/lite/resources/test_lite_matting_bgr.jpg";
  std::string save_fgr_path = "../../../examples/logs/test_lite_modnet_fgr_trt.jpg";
  std::string save_pha_path = "../../../examples/logs/test_lite_modnet_pha_trt.jpg";
  std::string save_merge_path = "../../../examples/logs/test_lite_modnet_merge_trt.jpg";
  std::string save_swap_path = "../../../examples/logs/test_lite_modnet_swap_trt.jpg";

  // Automatic storage: the handler is released on every exit path,
  // including an exception thrown while loading or running the engine.
  lite::trt::cv::matting::MODNet modnet(engine_path);

  lite::types::MattingContent content;
  cv::Mat img_bgr = cv::imread(test_img_path);
  cv::Mat bgr_mat = cv::imread(test_bgr_path);

  // 1. image matting (remove_noise = true, minimum_post_process = true).
  modnet.detect(img_bgr, content, true, true);

  if (content.flag)
  {
    if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
    // pha_mat is scaled by 255 before being written as an 8-bit image.
    if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
    if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);

    // 2. swap background: use the predicted foreground when there is one,
    //    otherwise fall back to the original image.
    cv::Mat out_mat;
    if (!content.fgr_mat.empty())
      lite::utils::swap_background(content.fgr_mat, content.pha_mat, bgr_mat, out_mat, true);
    else
      lite::utils::swap_background(img_bgr, content.pha_mat, bgr_mat, out_mat, false);

    if (!out_mat.empty())
    {
      cv::imwrite(save_swap_path, out_mat);
      std::cout << "Saved Swap Image Done!" << std::endl;
    }

    std::cout << "TensorRT Version MODNet Done!" << std::endl;
  }
#endif
}
144+
145+
97146
static void test_mnn()
98147
{
99148
#ifdef ENABLE_MNN
@@ -233,11 +282,12 @@ static void test_tnn()
233282

234283
static void test_lite()
235284
{
236-
test_default();
237-
test_onnxruntime();
238-
test_mnn();
239-
test_ncnn();
240-
test_tnn();
285+
// test_default();
286+
// test_onnxruntime();
287+
// test_mnn();
288+
// test_ncnn();
289+
// test_tnn();
290+
test_tensorrt();
241291
}
242292

243293
int main(__unused int argc, __unused char *argv[])

lite/models.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
#include "lite/trt/cv/trt_yolox.h"
133133
#include "lite/trt/cv/trt_yolov8.h"
134134
#include "lite/trt/cv/trt_yolov6.h"
135+
#include "lite/trt/cv/trt_modnet.h"
135136
#include "lite/trt/cv/trt_yolov5_blazeface.h"
136137
#include "lite/trt/cv/trt_lightenhance.h"
137138
#include "lite/trt/cv/trt_realesrgan.h"
@@ -731,9 +732,14 @@ namespace lite{
731732
typedef trtcv::TRTYOLO5Face _TRT_YOLO5Face;
732733
typedef trtcv::TRTLightEnhance _TRT_LightEnhance;
733734
typedef trtcv::TRTRealESRGAN _TRT_RealESRGAN;
735+
typedef trtcv::TRTMODNet _TRT_MODNet;
734736
namespace classification
735737
{
736738

739+
}
740+
namespace matting
741+
{
742+
typedef _TRT_MODNet MODNet;
737743
}
738744
namespace detection
739745
{

lite/trt/core/trt_utils.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,46 @@ void trtcv::utils::transform::trt_generate_latents(std::vector<float> &latents,
8383
for (size_t i = 0; i < total_size; ++i) {
8484
latents[i] = dist(gen) * init_noise_sigma;
8585
}
86+
}
87+
88+
// Zeroes every pixel of `alpha_pred` that does not belong to the largest
// connected foreground region.
//
// alpha_pred: alpha matte, modified in place. Rows are accessed as float, so
//             this presumably expects CV_32FC1 with values in [0, 1] -- TODO
//             confirm against callers.
// threshold:  fraction of 255 used as the binarisation cut-off.
void trtcv::utils::remove_small_connected_area(cv::Mat &alpha_pred, float threshold) {
    // Work on an 8-bit copy of the matte scaled to [0, 255].
    cv::Mat alpha_u8;
    alpha_pred.convertTo(alpha_u8, CV_8UC1, 255.f);

    // The cast truncates, e.g. 255 * 0.05 = 12.75 -> 12.
    // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L209
    const unsigned int cutoff = (unsigned int) (255.f * threshold);
    cv::Mat mask;
    cv::threshold(alpha_u8, mask, cutoff, 255, cv::THRESH_BINARY);

    // A morphological OPEN with a 3x3 ellipse removes speckle noise first.
    const cv::Mat ellipse = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3), cv::Point(-1, -1));
    cv::morphologyEx(mask, mask, cv::MORPH_OPEN, ellipse);

    // Label 8-connected regions; labels are 32-bit signed ints.
    cv::Mat component_ids = cv::Mat::zeros(alpha_pred.size(), CV_32S);
    cv::Mat component_stats, component_centroids;
    const int component_count = cv::connectedComponentsWithStats(
        mask, component_ids, component_stats, component_centroids, 8, CV_32S);
    if (component_count <= 1) return; // background only, nothing to remove.

    // Label 0 is the background; pick the foreground label with the largest
    // area (the first one wins on ties).
    int keep_id = 1;
    int keep_area = component_stats.at<int>(keep_id, cv::CC_STAT_AREA);
    for (int id = 2; id < component_count; ++id) {
        const int area = component_stats.at<int>(id, cv::CC_STAT_AREA);
        if (area > keep_area) {
            keep_area = area;
            keep_id = id;
        }
    }

    // Clear every pixel outside the retained component.
    const int rows = alpha_pred.rows;
    const int cols = alpha_pred.cols;
    for (int y = 0; y < rows; ++y) {
        const int *ids = component_ids.ptr<int>(y);
        float *alpha = alpha_pred.ptr<float>(y);
        for (int x = 0; x < cols; ++x) {
            if (ids[x] == keep_id) continue;
            alpha[x] = 0.f;
        }
    }
}

lite/trt/core/trt_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace trtcv
2727

2828
LITE_EXPORTS void trt_generate_latents(std::vector<float>& latents, int batch_size, int unet_channels, int latent_height, int latent_width, float init_noise_sigma);
2929
}
30+
LITE_EXPORTS void remove_small_connected_area(cv::Mat &alpha_pred, float threshold);
3031
}
3132
}
3233

lite/trt/cv/trt_modnet.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//
2+
// Created by wangzijian on 10/28/24.
3+
//
4+
5+
#include "trt_modnet.h"
6+
using trtcv::TRTMODNet;
7+
8+
// Converts a BGR image into the network input image: resized to 512x512,
// reordered to RGB and normalized to (pixel - mean_val) * scale_val as
// 32-bit float.
//
// input_mat: in/out. On return it refers to a newly allocated float image.
//            Every step writes into a fresh buffer, so pixel data that the
//            incoming header shares with another cv::Mat (for example the
//            caller's const image) is never modified.
void TRTMODNet::preprocess(cv::Mat &input_mat) {
    cv::Mat resized;
    cv::resize(input_mat, resized, cv::Size(512, 512));

    cv::Mat rgb;
    cv::cvtColor(resized, rgb, cv::COLOR_BGR2RGB);

    // convertTo computes x * alpha + beta for every element of every channel,
    // which equals (x - mean_val) * scale_val without relying on MatExpr
    // scalar arithmetic on a multi-channel Mat.
    cv::Mat normalized;
    rgb.convertTo(normalized, CV_32FC3, scale_val, -static_cast<double>(mean_val) * scale_val);

    input_mat = normalized;
}
17+
18+
19+
20+
// Runs MODNet on `mat` and fills `content` with the matting result.
//
// mat:                  input image; returns immediately if it is empty.
// content:              receives the result via generate_matting(). It is left
//                       untouched when any CUDA or TensorRT step fails.
// remove_noise:         forwarded to generate_matting().
// minimum_post_process: forwarded to generate_matting().
void TRTMODNet::detect(const cv::Mat &mat, types::MattingContent &content, bool remove_noise, bool minimum_post_process) {
    if (mat.empty()) return;
    cv::Mat preprocessed_mat = mat;
    preprocess(preprocessed_mat);

    const int batch_size = 1;
    const int channels = 3;
    const int input_h = preprocessed_mat.rows;
    const int input_w = preprocessed_mat.cols;
    const size_t element_count = static_cast<size_t>(batch_size) * channels * input_h * input_w;
    const size_t input_size = element_count * sizeof(float);
    // NOTE(review): the output buffer is sized like the 3-channel input, while
    // generate_matting() reads a single 512x512 plane. This looks like a
    // deliberate over-allocation; confirm against the engine's output shape.
    const size_t output_size = element_count * sizeof(float);

    // 1. (re)allocate the device buffers for this call.
    for (auto &buffer : buffers) {
        if (buffer) {
            cudaFree(buffer);
            buffer = nullptr;
        }
    }
    const cudaError_t input_alloc = cudaMalloc(&buffers[0], input_size);
    const cudaError_t output_alloc = cudaMalloc(&buffers[1], output_size);
    if (input_alloc != cudaSuccess || output_alloc != cudaSuccess || !buffers[0] || !buffers[1]) {
        std::cerr << "Failed to allocate CUDA memory" << std::endl;
        return;
    }

    // 2. build the CHW input tensor on the host and upload it.
    input_node_dims = {batch_size, channels, input_h, input_w};

    std::vector<float> input;
    trtcv::utils::transform::create_tensor(preprocessed_mat, input, input_node_dims, trtcv::utils::transform::CHW);

    if (cudaMemcpyAsync(buffers[0], input.data(), input_size,
                        cudaMemcpyHostToDevice, stream) != cudaSuccess) {
        std::cerr << "Failed to copy input to CUDA memory" << std::endl;
        return;
    }

    // 3. bind the tensors and run inference.
    nvinfer1::Dims MODNetDims;
    MODNetDims.nbDims = 4;
    MODNetDims.d[0] = batch_size;
    MODNetDims.d[1] = channels;
    MODNetDims.d[2] = input_h;
    MODNetDims.d[3] = input_w;

    // IO tensor 0 is treated as the input and tensor 1 as the output.
    auto input_tensor_name = trt_engine->getIOTensorName(0);
    auto output_tensor_name = trt_engine->getIOTensorName(1);
    trt_context->setTensorAddress(input_tensor_name, buffers[0]);
    trt_context->setTensorAddress(output_tensor_name, buffers[1]);
    trt_context->setInputShape(input_tensor_name, MODNetDims);

    bool status = trt_context->enqueueV3(stream);
    if (!status) {
        std::cerr << "Failed to infer by TensorRT." << std::endl;
        return;
    }

    // 4. download the result. The copy is enqueued on `stream`, so wait for
    //    the stream to drain before the host reads `output`.
    std::vector<float> output(element_count);
    if (cudaMemcpyAsync(output.data(), buffers[1], output_size,
                        cudaMemcpyDeviceToHost, stream) != cudaSuccess ||
        cudaStreamSynchronize(stream) != cudaSuccess) {
        std::cerr << "Failed to copy TensorRT output to host." << std::endl;
        return;
    }

    // 5. post process.
    generate_matting(output.data(), mat, content, remove_noise, minimum_post_process);
}
82+
83+
// Turns the raw network output into a MattingContent.
//
// trt_outputs:          host pointer to the network output; the first
//                       512 * 512 floats are read as the alpha matte.
// mat:                  original input image; supplies the target size and,
//                       unless minimum_post_process is set, the colours for
//                       the foreground and merge images.
// content:              receives pha_mat (always) and fgr_mat / merge_mat
//                       (only when minimum_post_process is false); flag is
//                       set to true.
// remove_noise:         when true, keep only the largest connected region of
//                       the alpha matte.
// minimum_post_process: when true, skip building fgr_mat and merge_mat.
void TRTMODNet::generate_matting(float *trt_outputs, const cv::Mat &mat, types::MattingContent &content,
                                 bool remove_noise, bool minimum_post_process) {

    const unsigned int h = mat.rows;
    const unsigned int w = mat.cols;

    const unsigned int out_h = 512;
    const unsigned int out_w = 512;

    // Wraps the caller's buffer without copying it.
    cv::Mat alpha_pred(out_h, out_w, CV_32FC1, trt_outputs);
    // post process
    if (remove_noise) trtcv::utils::remove_small_connected_area(alpha_pred, 0.05f);
    // resize alpha
    if (out_h != h || out_w != w)
        // resize allocates new continuous memory owned by alpha_pred.
        cv::resize(alpha_pred, alpha_pred, cv::Size(w, h));
    // Without a resize, clone so the matte owns its memory: the buffer that
    // trt_outputs points to is released after the caller returns.
    else alpha_pred = alpha_pred.clone();

    cv::Mat pmat = alpha_pred; // ref
    content.pha_mat = pmat; // memory is reference counted inside OpenCV.

    if (!minimum_post_process)
    {
        // MODNet only predicts alpha, not a foreground. The synthetic
        // foreground and merge images are optional, so they are skipped
        // (left empty) when minimum_post_process is set to speed things up.
        cv::Mat mat_copy;
        mat.convertTo(mat_copy, CV_32FC3);
        std::vector<cv::Mat> mat_channels;
        cv::split(mat_copy, mat_channels);
        cv::Mat bmat = mat_channels.at(0);
        cv::Mat gmat = mat_channels.at(1);
        cv::Mat rmat = mat_channels.at(2); // ref only, zero-copy.
        bmat = bmat.mul(pmat);
        gmat = gmat.mul(pmat);
        rmat = rmat.mul(pmat);
        cv::Mat rest = 1.f - pmat;
        // NOTE(review): the channels above are already multiplied by alpha, so
        // the merge image weights the foreground by alpha squared. Kept
        // unchanged here; confirm whether this is intended.
        cv::Mat mbmat = bmat.mul(pmat) + rest * 153.f;
        cv::Mat mgmat = gmat.mul(pmat) + rest * 255.f;
        cv::Mat mrmat = rmat.mul(pmat) + rest * 120.f;
        std::vector<cv::Mat> fgr_channel_mats, merge_channel_mats;
        fgr_channel_mats.push_back(bmat);
        fgr_channel_mats.push_back(gmat);
        fgr_channel_mats.push_back(rmat);
        merge_channel_mats.push_back(mbmat);
        merge_channel_mats.push_back(mgmat);
        merge_channel_mats.push_back(mrmat);

        cv::merge(fgr_channel_mats, content.fgr_mat);
        cv::merge(merge_channel_mats, content.merge_mat);

        content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3);
        content.merge_mat.convertTo(content.merge_mat, CV_8UC3);
    }

    content.flag = true;
}

lite/trt/cv/trt_modnet.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//
2+
// Created by wangzijian on 10/28/24.
3+
//
4+
5+
#ifndef LITE_AI_TOOLKIT_TRT_MODNET_H
6+
#define LITE_AI_TOOLKIT_TRT_MODNET_H
7+
8+
#include "lite/trt/core/trt_core.h"
9+
#include "lite/trt/core/trt_utils.h"
10+
11+
namespace trtcv{
12+
class LITE_EXPORTS TRTMODNet : public BasicTRTHandler {
public:
    // Forwards the engine path and thread count to BasicTRTHandler.
    explicit TRTMODNet(const std::string &_trt_model_path, unsigned int _num_threads = 1)
        : BasicTRTHandler(_trt_model_path, _num_threads) {}

    // Runs matting on `mat` and stores the result in `content`.
    void detect(const cv::Mat &mat,
                types::MattingContent &content,
                bool remove_noise = false,
                bool minimum_post_process = false);

private:
    // Input normalization: (pixel - mean_val) * scale_val.
    static constexpr float mean_val = 127.5f; // RGB
    static constexpr float scale_val = 1.f / 127.5f;

    // Prepares `input_mat` in place as the network input image.
    void preprocess(cv::Mat &input_mat);

    // Builds the MattingContent from the raw network output.
    void generate_matting(float *trt_outputs,
                          const cv::Mat &mat,
                          types::MattingContent &content,
                          bool remove_noise = false,
                          bool minimum_post_process = false);
};
30+
}
31+
32+
33+
34+
#endif //LITE_AI_TOOLKIT_TRT_MODNET_H

0 commit comments

Comments
 (0)