forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
onnx2trt_utils.hpp
379 lines (294 loc) · 18.6 KB
/
onnx2trt_utils.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "ShapedWeights.hpp"
#include "ShapeTensor.hpp"
#include "Status.hpp"
#include "trt_utils.hpp"
#include <NvInfer.h>
#include <onnx/onnx_pb.h>
#include <cstring> // For std::memcpy
#include <iostream>
#include <numeric>
#include <sstream>
#include <limits>
// Logging helpers for the importer.
//
// LOG(msg, severity) builds a message with a std::stringstream and forwards it to
// ctx->logger().log(severity, ...). The macro therefore requires a variable named `ctx`
// (exposing logger()) to be in scope at the call site, and `__FILENAME__` to be defined
// (NOTE(review): presumably supplied by the build system - it is not defined in this header).
//
// msg:      anything that can follow `ss <<`, including a chain such as "x = " << x. It is
//           intentionally NOT parenthesized so that such chains keep working.
// severity: an nvinfer1::ILogger::Severity value. It is parenthesized in the expansion so that an
//           arbitrary expression (e.g. a conditional) cannot re-associate with the `<=` comparison.
//
// Messages whose severity compares <= kWARNING are prefixed with "<file>:<line>: ".
// Note: the local stream is named `ss`; a `msg` expression that refers to a caller variable
// called `ss` would bind to the macro's stream instead.
#define LOG(msg, severity) \
    do \
    { \
        std::stringstream ss{}; \
        if ((severity) <= nvinfer1::ILogger::Severity::kWARNING) \
        { \
            ss << __FILENAME__ << ":" << __LINE__ << ": "; \
        } \
        ss << msg; \
        ctx->logger().log((severity), ss.str().c_str()); \
    } while (0)

// Convenience wrappers that fix the severity passed to LOG.
#define LOG_VERBOSE(msg) LOG(msg, nvinfer1::ILogger::Severity::kVERBOSE)
#define LOG_INFO(msg) LOG(msg, nvinfer1::ILogger::Severity::kINFO)
#define LOG_WARNING(msg) LOG(msg, nvinfer1::ILogger::Severity::kWARNING)
#define LOG_ERROR(msg) LOG(msg, nvinfer1::ILogger::Severity::kERROR)
// Overloads of operator<< on TensorRT types must be defined inside nvinfer1
// so that argument-dependent lookup works as expected. Declared static to
// avoid symbol clashing when statically linking with other TensorRT libraries
namespace nvinfer1
{
// Writes `count` elements starting at `begin` to `stream` as a parenthesized,
// comma-separated list, e.g. "(1, 2, 3)". A count of zero or less produces "()".
// Returns `stream` so the call can be chained.
template <typename T>
static std::ostream& printSequence(std::ostream& stream, const T* begin, int count)
{
    stream << "(";
    for (int idx = 0; idx < count; ++idx)
    {
        // The separator goes between elements only, never after the last one.
        if (idx != 0)
        {
            stream << ", ";
        }
        stream << begin[idx];
    }
    stream << ")";
    return stream;
}
// Streams a Dims object as "(d0, d1, ...)", printing the first shape.nbDims entries of shape.d.
static std::ostream& operator<<(std::ostream& stream, const nvinfer1::Dims& shape)
{
    printSequence(stream, shape.d, shape.nbDims);
    return stream;
}
// Streams a Permutation as "(o0, o1, ...)". A Permutation is printed without any rank
// information, so all nvinfer1::Dims::MAX_DIMS entries of perm.order are always written.
static std::ostream& operator<<(std::ostream& stream, const nvinfer1::Permutation& perm)
{
    const int entryCount = nvinfer1::Dims::MAX_DIMS;
    printSequence(stream, perm.order, entryCount);
    return stream;
}
// Streams a human-readable name for a TensorRT data type ("float32", "float16", "int8",
// "int32" or "bool"). Throws std::runtime_error("Unknown dtype") for any other value,
// in which case nothing is written to the stream.
static std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& dtype)
{
    const char* label = nullptr;
    switch (dtype)
    {
    case nvinfer1::DataType::kFLOAT: label = "float32"; break;
    case nvinfer1::DataType::kHALF: label = "float16"; break;
    case nvinfer1::DataType::kINT8: label = "int8"; break;
    case nvinfer1::DataType::kINT32: label = "int32"; break;
    case nvinfer1::DataType::kBOOL: label = "bool"; break;
    default: throw std::runtime_error("Unknown dtype");
    }
    return stream << label;
}
} // namespace nvinfer1
namespace onnx2trt
{
// Deleter functor for nvinfer1::IPluginV2 objects. It is the deleter type of the
// std::unique_ptr returned by createPlugin (declared further down in this header).
// Only the call operator's declaration is visible here; how the plugin is actually
// released is determined by its out-of-line definition.
struct PluginDeleter
{
// Releases the plugin pointed to by t.
void operator()(nvinfer1::IPluginV2* t);
};
// Helper function to calculate the volume of a Dims object
int64_t volume(const nvinfer1::Dims& dims);
// Helper function to get the size in bytes of an ONNX datatype
int getDtypeSize(int32_t onnxDtype);
// Helper function to add a single scalar value to the TRT network as a constant layer.
//
// ctx:    importer context; supplies the temporary weights (createTempWeights) and the network.
// scalar: value written into the first element of the newly created weights.
// type:   ONNX data type passed to createTempWeights.
// shape:  shape of the constant. Its volume must be 1 (checked with assert, so only in builds
//         without NDEBUG). Defaults to nvinfer1::Dims{0}.
//
// Returns the IConstantLayer produced by ctx->network()->addConstant.
// NOTE(review): nothing here checks that sizeof(ScalarType) matches the byte size of `type`;
// callers are presumably expected to pass a matching pair - confirm at the call sites.
template <typename ScalarType>
inline nvinfer1::IConstantLayer* addConstantScalar(
    IImporterContext* ctx, ScalarType scalar, ShapedWeights::DataType type, nvinfer1::Dims shape = nvinfer1::Dims{0})
{
    assert(volume(shape) == 1 && "Cannot add constant scalar with a shape that has volume > 1");
    ShapedWeights holder = ctx->createTempWeights(type, shape);
    // View the raw weight storage as ScalarType and fill its only element.
    auto* slot = static_cast<ScalarType*>(holder.values);
    *slot = scalar;
    return ctx->network()->addConstant(holder.shape, holder);
}
// Helper function to create a constant tensor from a vector of values and a shape.
//
// ctx:    importer context; supplies the temporary weights (createTempWeights) and the network.
// values: elements to copy into the constant, in storage order.
// type:   ONNX data type passed to createTempWeights; its byte size must equal sizeof(ScalarType).
// shape:  shape of the constant; its volume must equal values.size().
//
// Both preconditions are enforced with assert and are therefore only checked in builds
// without NDEBUG. Returns the IConstantLayer produced by ctx->network()->addConstant.
template <typename ScalarType>
inline nvinfer1::IConstantLayer* addConstant(
    IImporterContext* ctx, const std::vector<ScalarType>& values, ShapedWeights::DataType type, nvinfer1::Dims shape)
{
    assert(volume(shape) == static_cast<int64_t>(values.size()) && "Shape does not match number of values provided");
    // getDtypeSize returns int; cast sizeof so the comparison is not signed/unsigned.
    assert(static_cast<int>(sizeof(ScalarType)) == getDtypeSize(type)
        && "ONNX dtype does not have the same size as the value type");
    ShapedWeights weights = ctx->createTempWeights(type, shape);
    // std::memcpy has undefined behavior when given a null pointer, even for a zero-byte copy.
    // An empty vector may return nullptr from data(), so skip the copy when there is nothing to copy.
    if (!values.empty())
    {
        std::memcpy(weights.values, values.data(), values.size() * sizeof(ScalarType));
    }
    return ctx->network()->addConstant(weights.shape, weights);
}
// Identifies one of three operations: shift, scale or power.
// NOTE(review): presumably these correspond to the shift/scale/power weights taken by
// scaleHelper - the usage sites are not visible in this header, so confirm there.
// Unscoped enum: the enumerators are visible directly in namespace onnx2trt and take the
// default values 0, 1, 2 in declaration order.
enum ScaleOp
{
kSHIFT,
kSCALE,
kPOWER,
};
// Helper function to import ONNX activation nodes into TRT
NodeImportResult activationHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
std::vector<TensorOrWeights>& inputs, nvinfer1::ActivationType op, float* alpha = nullptr, float* beta = nullptr);
// Add clipping to a tensor if clip is a valid value.
nvinfer1::ITensor* addClip(IImporterContext* ctx, nvinfer1::ITensor* input, float clip);
// Helper function to import ArgMax and ArgMin nodes into TRT
NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op);
//! If t has rank less than nbDims, reshape it to have nbDims by prepending ones to its dimensions.
//! Assert failure if t has rank greater than nbDims.
Status broadcastTensor(IImporterContext* ctx, nvinfer1::ITensor*& t, const int nbDims);
// Helper function to broadcast two tensors to the larger one's shape
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2);
// Helper function to broadcast three tensors to the largest one's shape
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2, nvinfer1::ITensor*& t3);
// Helper function to check that two shapes conform to the broadcasting rules
Status isBroadcastValid(IImporterContext* ctx, const nvinfer1::Dims& firstShape, const nvinfer1::Dims& secondShape);
// Helper function to calculate the bias tensor for GatherElements.
std::vector<int32_t> calculateBias(
const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis);
// Helper function to calculate and return a vector representation of the pitches of a given shape
std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims);
// Helper function to check that linear resize can be used
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);
// Helper function to add a Cast layer in the network
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype);
// Helper function for constantOfShape operator. Input shape must be a shape tensor
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
nvinfer1::ITensor* constant, nvinfer1::ITensor* shape);
// Helper function to convert an ONNX axis into a TRT axis
Status convertAxis(int& axis, int nbDims);
// Helper function to convert an ONNX datatype into a TRT datatype
bool convertDtype(int32_t onnx_dtype, nvinfer1::DataType* trt_dtype);
// Helper function to convert INT64 weight values into INT32
int32_t* convertINT64(const int64_t* weightValues, nvinfer1::Dims shape, IImporterContext* ctx);
// Helper function to convert ONNX padding into TRT padding. Will update startTensor and totalPaddingTensor by reference
bool convertOnnxPadding(IImporterContext* ctx, int32_t nbInputDims, const std::vector<int32_t>& onnxPadding,
nvinfer1::ITensor*& startTensor, nvinfer1::ITensor*& totalPaddingTensor);
// Helper function to check if all of the values in the shift tensor are zeros
bool shiftIsAllZeros(const ShapedWeights& shiftInt8);
// Helper function to create zero shifts for QuantizeLinear/DequantizeLinear ops
onnx2trt::ShapedWeights createZeroShifts(const onnx2trt::ShapedWeights& shiftInt8, int32_t type, IImporterContext* ctx);
// Helper function to create a tensor of all zeros with the same shape as a data tensor
nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* data);
// Helper function to convert an ONNX weight into a ShapedWeights object
bool convertOnnxWeights(
const ::ONNX_NAMESPACE::TensorProto& onnxTensor, onnx2trt::ShapedWeights* weights, IImporterContext* ctx);
// Helper function to convert multi input convolution
NodeImportResult convMultiInput(
IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector<TensorOrWeights>& inputs);
// Helper function to convert a 1D tensor into a scalar
nvinfer1::ITensor* convertToScalar(IImporterContext* ctx, nvinfer1::ITensor* inpTensor);
// Helper function to convert a ShapedWeights object into a tensor
nvinfer1::ITensor& convertToTensor(TensorOrWeights& input, IImporterContext* ctx);
// Helper function to convert a ShapedWeights object into a scalar
nvinfer1::ITensor* convertToScalar(TensorOrWeights& input, IImporterContext* ctx);
// Helper function to provide a ceiling-rounding division between two integers
int divCeil(int n, int d);
// Helper function to check that the input data types for an elementwise operation are supported
bool elementwiseCheck(const std::vector<TensorOrWeights>& inputs, const nvinfer1::ElementWiseOperation op);
// Helper function to import an ONNX elementwise op into TRT
NodeImportResult elementwiseHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node,
const std::vector<TensorOrWeights>& inputs, nvinfer1::ElementWiseOperation binary_op);
// Helper function to flatten a tensor on a given axis
nvinfer1::ITensor* flattenTensor(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, nvinfer1::ITensor& tensor, int axis = 0, bool regLayer = false);
// Gathers the specified dimension from a shape tensor, e.g. gatherDimension(shapeTensor=(7, 6, 5), dim=2) would return 5.
// The `shape` argument specifies the shape of the returned tensor and must have a volume of 1.
nvinfer1::ITensor* gatherDimension(
IImporterContext* ctx, nvinfer1::ITensor* shapeTensor, int dim, nvinfer1::Dims shape);
// Helper function to generate padding values for convTranspose
void generatePadding(nvinfer1::Dims inputShape, nvinfer1::Dims outputShape, nvinfer1::Dims kernelSize,
nvinfer1::Dims strides, nvinfer1::Dims dilations, const int nbSpatialDims, nvinfer1::Dims& begPadding,
nvinfer1::Dims& endPadding, nvinfer1::Dims& outputPadding, nvinfer1::PaddingMode paddingMode);
// Helper function to get default ONNX activation alpha values
float getActivationDefaultAlpha(nvinfer1::ActivationType type);
// Helper function to get default ONNX activation beta values
float getActivationDefaultBeta(nvinfer1::ActivationType type);
// Helper function to get the length of the specified axis
nvinfer1::ITensor* getAxisLength(
IImporterContext* ctx, nvinfer1::ITensor* inpTensor, int axis, nvinfer1::Dims shape = nvinfer1::Dims{0});
// Helper function to calculate the output size of a convolution node given its attributes
int getConvOutputSize(int input_size, int filter_size, int stride, int dilation_rate, int total_padding);
// Helper function to get the name of an ONNX datatype as a string
const char* getDtypeName(int32_t onnxDtype);
// Helper function to get kernel attributes for various ONNX nodes
void getKernelParams(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& onnx_node, nvinfer1::Dims* kernel_size,
nvinfer1::Dims* strides, nvinfer1::Dims* beg_padding, nvinfer1::Dims* end_padding,
nvinfer1::PaddingMode& paddingMode, bool& count_exclude_padding, nvinfer1::Dims* dilations = nullptr,
nvinfer1::Dims* output_padding = nullptr, const bool poolingCeilMode = false);
// Helper function to get the scaling mode for TRT's scale layer
nvinfer1::ScaleMode getScaleMode(nvinfer1::Dims const& weights_shape, nvinfer1::Dims const& tensor_shape);
// Helper function to map ONNX Global Pooling ops into TensorRT.
nvinfer1::ITensor* globalPoolingHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, nvinfer1::ITensor& tensor, nvinfer1::ReduceOperation op);
// Helper function to determine if a shape contains dynamic dimensions
bool isDynamic(const nvinfer1::Dims& shape);
// Helper function to determine if a ONNX tensor is empty
bool isOnnxTensorEmpty(const ::ONNX_NAMESPACE::TensorProto& onnxTensor);
// Helper function to load a creator from the registry
nvinfer1::IPluginCreator* importPluginCreator(
const std::string& pluginName, const std::string& pluginVersion, const std::string& pluginNamespace = "");
// Helper function to get a plugin from the PluginRegistry
std::unique_ptr<nvinfer1::IPluginV2, PluginDeleter> createPlugin(const std::string& name,
nvinfer1::IPluginCreator* pluginCreator, const std::vector<nvinfer1::PluginField>& pluginFields);
// Helper function to determine if a transpose is required
bool isTransposeRequired(nvinfer1::Dims const& shape, nvinfer1::Permutation const& perm);
// Helper function to import LSTM ops through the legacy CUDNN path
NodeImportResult lstmLegacyImporter(
IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs);
// Helper function to create and fill a Dims object with defined values
nvinfer1::Dims makeDims(int nbDims, int val);
// Helper function to parse activation values for LSTM nodes
std::vector<float> parseLSTMActivationValues(const std::vector<nvinfer1::ActivationType>& activationTypes,
const std::vector<float>& activationValues, bool isAlpha);
// Helper function to read weights from an external file
bool parseExternalWeights(IImporterContext* ctx, std::string file, std::string path, int64_t offset, int64_t length,
std::vector<char>& weightsBuf, size_t& size);
// Helper function to map various ONNX pooling ops into TensorRT.
NodeImportResult poolingHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node,
std::vector<TensorOrWeights>& inputs, nvinfer1::PoolingType type);
// Helper function to import reduce ops into TRT
NodeImportResult reduceTensor(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, TensorOrWeights input,
nvinfer1::ReduceOperation operation, TensorOrWeights inputAxes = TensorOrWeights());
// Helper function to shape a Tensor given a new shape
nvinfer1::ITensor* reshapeTensor(IImporterContext* ctx, nvinfer1::ITensor& tensor, nvinfer1::Dims shape);
// Helper function to map attributes to a TRT scale layer
NodeImportResult scaleHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_,
nvinfer1::ScaleMode mode, const nvinfer1::Weights& shift, const nvinfer1::Weights& scale,
const nvinfer1::Weights& power, const char* shiftName, const char* scaleName);
// Helper function to set an ONNX attribute
void setAttr(
nvinfer1::Dims* trtAttr, ::ONNX_NAMESPACE::AttributeProto const* onnxAttr, int nbSpatialDims, int defaultVal);
// Helper function to slice away elements on a given axis dimension
nvinfer1::ITensor* sliceAcrossAxis(
IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* data, const int axis);
// Helper function to filter out shape tensor outputs for layers that do not support it
bool supportsShapeTensor(nvinfer1::LayerType type, nvinfer1::ElementWiseOperation eleOp,
nvinfer1::ReduceOperation redOp, nvinfer1::FillOperation fillOp);
// Helper function to squeeze a tensor on a given set of axes
nvinfer1::ITensor* squeezeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor, const std::vector<int>& axes, bool regLayer = false);
// Helper function to transpose a tensor given a permutation
nvinfer1::ITensor* transposeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
nvinfer1::ITensor& tensor, nvinfer1::Permutation const& perm);
// Helper function to import ONNX unary ops into TRT
NodeImportResult unaryHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, TensorOrWeights& input,
nvinfer1::UnaryOperation op);
// Helper function to unsqueeze tensors on a given set of axes
nvinfer1::ITensor* unsqueezeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
nvinfer1::ITensor& tensor, const std::vector<int>& axes, bool regLayer = false);
// Helper function to copy the contents of a weights input into a std::vector.
//
// weights:      input to read; must hold weights of ONNX type INT32, INT64 or BOOL.
// weightVector: output; resized to the number of weight elements and overwritten, with each
//               element converted to WeightType by std::copy's element-wise assignment.
//
// The two ASSERTs report ErrorCode::kUNSUPPORTED_NODE when the input is not weights and
// ErrorCode::kINVALID_NODE when its type is not one of the three accepted types
// (NOTE(review): ASSERT is a project macro defined elsewhere; presumably it returns a failing
// Status from this function - confirm). Otherwise returns Status(ErrorCode::kSUCCESS).
template <typename WeightType>
Status weightsToVector(TensorOrWeights weights, std::vector<WeightType>* weightVector)
{
    ASSERT(weights.is_weights(), ErrorCode::kUNSUPPORTED_NODE);
    ASSERT((weights.weights().type == ::ONNX_NAMESPACE::TensorProto::INT32)
            || (weights.weights().type == ::ONNX_NAMESPACE::TensorProto::INT64)
            || (weights.weights().type == ::ONNX_NAMESPACE::TensorProto::BOOL),
        ErrorCode::kINVALID_NODE);

    auto&& source = weights.weights();
    const auto elementCount = source.count();
    weightVector->resize(elementCount);
    const auto destination = weightVector->begin();
    const auto dtype = source.type;

    // Reinterpret the raw storage according to the ONNX type, then copy element by element.
    if (dtype == ::ONNX_NAMESPACE::TensorProto::INT64)
    {
        auto first = static_cast<int64_t*>(source.values);
        std::copy(first, first + elementCount, destination);
    }
    else if (dtype == ::ONNX_NAMESPACE::TensorProto::INT32)
    {
        auto first = static_cast<int32_t*>(source.values);
        std::copy(first, first + elementCount, destination);
    }
    else if (dtype == ::ONNX_NAMESPACE::TensorProto::BOOL)
    {
        auto first = static_cast<bool*>(source.values);
        std::copy(first, first + elementCount, destination);
    }
    return Status(ErrorCode::kSUCCESS);
}
// Helper function to get the name of an ONNX node. If the node has no name, the name of its first output is used.
const std::string getNodeName(const ::ONNX_NAMESPACE::NodeProto& node);
//! Decode in place the starts and ends indices according to ONNX Slice rules.
void decodeOnnxStartsAndEnds(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& steps, ShapeTensor& starts, ShapeTensor& ends);
//! Return ShapeTensor representing size of result of Slice.
//! starts and ends should first be decoded by decodeOnnxStartsAndEnds.
ShapeTensor computeSliceSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends,
const ShapeTensor& steps, const ShapeTensor& dims);
//! Return subscripts such that gather(concat(x,y),subscripts)
//! will return x with x[subscripts[i]] replaced by y[i].
ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims);
//! Helper function to add SoftMax layer.
nvinfer1::ITensor* addSoftmax(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& input);
// Helper function to import ONNX scatter nodes into TRT
NodeImportResult addScatterLayer(
IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector<TensorOrWeights>& inputs, nvinfer1::ScatterMode mode, int32_t axis = 0);
} // namespace onnx2trt