1
+ //
2
+ // Created by wangzijian on 10/28/24.
3
+ //
4
+
5
+ #include " trt_modnet.h"
6
+ using trtcv::TRTMODNet;
7
+
8
+ void TRTMODNet::preprocess (cv::Mat &input_mat) {
9
+ cv::Mat ori_input_mat = input_mat;
10
+ cv::resize (input_mat,input_mat,cv::Size (512 ,512 ));
11
+ cv::cvtColor (input_mat,input_mat,cv::COLOR_BGR2RGB);
12
+ if (input_mat.type () != CV_32FC3) input_mat.convertTo (input_mat, CV_32FC3);
13
+ else input_mat = input_mat;
14
+ input_mat = (input_mat -mean_val) * scale_val;
15
+
16
+ }
17
+
18
+
19
+
20
+ void TRTMODNet::detect (const cv::Mat &mat, types::MattingContent &content, bool remove_noise, bool minimum_post_process) {
21
+ if (mat.empty ()) return ;
22
+ cv::Mat preprocessed_mat = mat;
23
+ preprocess (preprocessed_mat);
24
+
25
+ const int batch_size = 1 ;
26
+ const int channels = 3 ;
27
+ const int input_h = preprocessed_mat.rows ;
28
+ const int input_w = preprocessed_mat.cols ;
29
+ const size_t input_size = batch_size * channels * input_h * input_w * sizeof (float );
30
+ const size_t output_size = batch_size * channels * input_h * input_w * sizeof (float );
31
+
32
+ for (auto & buffer : buffers) {
33
+ if (buffer) {
34
+ cudaFree (buffer);
35
+ buffer = nullptr ;
36
+ }
37
+ }
38
+ cudaMalloc (&buffers[0 ], input_size);
39
+ cudaMalloc (&buffers[1 ], output_size);
40
+ if (!buffers[0 ] || !buffers[1 ]) {
41
+ std::cerr << " Failed to allocate CUDA memory" << std::endl;
42
+ return ;
43
+ }
44
+
45
+ input_node_dims = {batch_size, channels, input_h, input_w};
46
+
47
+ std::vector<float > input;
48
+ trtcv::utils::transform::create_tensor (preprocessed_mat,input,input_node_dims,trtcv::utils::transform::CHW);
49
+
50
+ // 3.infer
51
+ cudaMemcpyAsync (buffers[0 ], input.data (), input_size,
52
+ cudaMemcpyHostToDevice, stream);
53
+
54
+ nvinfer1::Dims MODNetDims;
55
+ MODNetDims.nbDims = 4 ;
56
+ MODNetDims.d [0 ] = batch_size;
57
+ MODNetDims.d [1 ] = channels;
58
+ MODNetDims.d [2 ] = input_h;
59
+ MODNetDims.d [3 ] = input_w;
60
+
61
+ auto input_tensor_name = trt_engine->getIOTensorName (0 );
62
+ auto output_tensor_name = trt_engine->getIOTensorName (1 );
63
+ trt_context->setTensorAddress (input_tensor_name, buffers[0 ]);
64
+ trt_context->setTensorAddress (output_tensor_name, buffers[1 ]);
65
+ trt_context->setInputShape (input_tensor_name, MODNetDims);
66
+
67
+ bool status = trt_context->enqueueV3 (stream);
68
+ if (!status){
69
+ std::cerr << " Failed to infer by TensorRT." << std::endl;
70
+ return ;
71
+ }
72
+
73
+
74
+
75
+ std::vector<float > output (batch_size * channels * input_h * input_w);
76
+ cudaMemcpyAsync (output.data (), buffers[1 ], output_size,
77
+ cudaMemcpyDeviceToHost, stream);
78
+
79
+ // post
80
+ generate_matting (output.data (),mat,content, remove_noise, minimum_post_process);
81
+ }
82
+
83
+ void TRTMODNet::generate_matting (float *trt_outputs, const cv::Mat &mat, types::MattingContent &content,
84
+ bool remove_noise, bool minimum_post_process) {
85
+
86
+ const unsigned int h = mat.rows ;
87
+ const unsigned int w = mat.cols ;
88
+
89
+
90
+ const unsigned int out_h = 512 ;
91
+ const unsigned int out_w = 512 ;
92
+
93
+ cv::Mat alpha_pred (out_h, out_w, CV_32FC1, trt_outputs);
94
+ cv::imwrite (" /home/lite.ai.toolkit/modnet.jpg" ,alpha_pred);
95
+ // post process
96
+ if (remove_noise) trtcv::utils::remove_small_connected_area (alpha_pred,0 .05f );
97
+ // resize alpha
98
+ if (out_h != h || out_w != w)
99
+ // already allocated a new continuous memory after resize.
100
+ cv::resize (alpha_pred, alpha_pred, cv::Size (w, h));
101
+ // need clone to allocate a new continuous memory if not performed resize.
102
+ // The memory elements point to will release after return.
103
+ else alpha_pred = alpha_pred.clone ();
104
+
105
+ cv::Mat pmat = alpha_pred; // ref
106
+ content.pha_mat = pmat; // auto handle the memory inside ocv with smart ref.
107
+
108
+ if (!minimum_post_process)
109
+ {
110
+ // MODNet only predict Alpha, no fgr. So,
111
+ // the fake fgr and merge mat may not need,
112
+ // let the fgr mat and merge mat empty to
113
+ // Speed up the post processes.
114
+ cv::Mat mat_copy;
115
+ mat.convertTo (mat_copy, CV_32FC3);
116
+ // merge mat and fgr mat may not need
117
+ std::vector<cv::Mat> mat_channels;
118
+ cv::split (mat_copy, mat_channels);
119
+ cv::Mat bmat = mat_channels.at (0 );
120
+ cv::Mat gmat = mat_channels.at (1 );
121
+ cv::Mat rmat = mat_channels.at (2 ); // ref only, zero-copy.
122
+ bmat = bmat.mul (pmat);
123
+ gmat = gmat.mul (pmat);
124
+ rmat = rmat.mul (pmat);
125
+ cv::Mat rest = 1 .f - pmat;
126
+ cv::Mat mbmat = bmat.mul (pmat) + rest * 153 .f ;
127
+ cv::Mat mgmat = gmat.mul (pmat) + rest * 255 .f ;
128
+ cv::Mat mrmat = rmat.mul (pmat) + rest * 120 .f ;
129
+ std::vector<cv::Mat> fgr_channel_mats, merge_channel_mats;
130
+ fgr_channel_mats.push_back (bmat);
131
+ fgr_channel_mats.push_back (gmat);
132
+ fgr_channel_mats.push_back (rmat);
133
+ merge_channel_mats.push_back (mbmat);
134
+ merge_channel_mats.push_back (mgmat);
135
+ merge_channel_mats.push_back (mrmat);
136
+
137
+ cv::merge (fgr_channel_mats, content.fgr_mat );
138
+ cv::merge (merge_channel_mats, content.merge_mat );
139
+
140
+ content.fgr_mat .convertTo (content.fgr_mat , CV_8UC3);
141
+ content.merge_mat .convertTo (content.merge_mat , CV_8UC3);
142
+ }
143
+
144
+ content.flag = true ;
145
+
146
+ }
0 commit comments