Skip to content

Commit 5250bad

Browse files
committed
Fix setBindingDimensions() for TensorrtAPI #181
1 parent f6b5ac2 commit 5250bad

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

engine/src/nn/tensorrtapi.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ void TensorrtAPI::bind_executor()
149149
{
150150
// create an exectution context for applying inference
151151
context = SampleUniquePtr<nvinfer1::IExecutionContext>(engine->createExecutionContext());
152-
context->setBindingDimensions(0, Dims4(batchSize,nnDesign.inputShape.v[1],nnDesign.inputShape.v[2],nnDesign.inputShape.v[3]));
152+
Dims inputDims;
153+
set_dims(inputDims, nnDesign.inputShape);
154+
context->setBindingDimensions(0, inputDims);
153155

154156
// create buffers object with respect to the engine and batch size
155157
CHECK(cudaStreamCreate(&stream));
@@ -233,9 +235,12 @@ ICudaEngine* TensorrtAPI::create_cuda_engine_from_onnx()
233235
set_config_settings(config, 1_GiB, calibrator, calibrationStream);
234236

235237
IOptimizationProfile* profile = builder->createOptimizationProfile();
236-
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kMIN, network->getInput(0)->getDimensions());
237-
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kOPT, network->getInput(0)->getDimensions());
238-
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kMAX, network->getInput(0)->getDimensions());
238+
239+
Dims inputDims = network->getInput(0)->getDimensions();
240+
inputDims.d[0] = batchSize;
241+
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kMIN, inputDims);
242+
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kOPT, inputDims);
243+
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kMAX, inputDims);
239244
config->addOptimizationProfile(profile);
240245

241246
// build an engine from the TensorRT network with a given configuration struct
@@ -431,4 +436,12 @@ void set_shape(nn_api::Shape &shape, const Dims &dims)
431436
}
432437
}
433438

439+
void set_dims(Dims &dims, const nn_api::Shape &shape)
440+
{
441+
dims.nbDims = shape.nbDims;
442+
for (int idx = 0; idx < shape.nbDims; ++idx) {
443+
dims.d[idx] = shape.v[idx];
444+
}
445+
}
446+
434447
#endif

engine/src/nn/tensorrtapi.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,13 @@ string generate_trt_file_path(const string &modelDirectory, unsigned int batchSi
194194
*/
195195
void set_shape(nn_api::Shape& shape, const nvinfer1::Dims& dims);
196196

197+
/**
198+
* @brief set_dims Converter function from nn_api::Shape to nvinfer1::Dims
199+
* @param dims Dims object to be set
200+
* @param shape Target object
201+
*/
202+
void set_dims(Dims &dims, const nn_api::Shape &shape);
203+
197204
#endif
198205

199206
#endif // TENSORRTAPI_H

0 commit comments

Comments
 (0)