@@ -149,7 +149,9 @@ void TensorrtAPI::bind_executor()
149
149
{
150
150
// create an exectution context for applying inference
151
151
context = SampleUniquePtr<nvinfer1::IExecutionContext>(engine->createExecutionContext ());
152
- context->setBindingDimensions (0 , Dims4 (batchSize,nnDesign.inputShape .v [1 ],nnDesign.inputShape .v [2 ],nnDesign.inputShape .v [3 ]));
152
+ Dims inputDims;
153
+ set_dims (inputDims, nnDesign.inputShape );
154
+ context->setBindingDimensions (0 , inputDims);
153
155
154
156
// create buffers object with respect to the engine and batch size
155
157
CHECK (cudaStreamCreate (&stream));
@@ -233,9 +235,12 @@ ICudaEngine* TensorrtAPI::create_cuda_engine_from_onnx()
233
235
set_config_settings (config, 1_GiB, calibrator, calibrationStream);
234
236
235
237
IOptimizationProfile* profile = builder->createOptimizationProfile ();
236
- profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kMIN , network->getInput (0 )->getDimensions ());
237
- profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kOPT , network->getInput (0 )->getDimensions ());
238
- profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kMAX , network->getInput (0 )->getDimensions ());
238
+
239
+ Dims inputDims = network->getInput (0 )->getDimensions ();
240
+ inputDims.d [0 ] = batchSize;
241
+ profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kMIN , inputDims);
242
+ profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kOPT , inputDims);
243
+ profile->setDimensions (nnDesign.inputLayerName .c_str (), OptProfileSelector::kMAX , inputDims);
239
244
config->addOptimizationProfile (profile);
240
245
241
246
// build an engine from the TensorRT network with a given configuration struct
@@ -431,4 +436,12 @@ void set_shape(nn_api::Shape &shape, const Dims &dims)
431
436
}
432
437
}
433
438
439
+ void set_dims (Dims &dims, const nn_api::Shape &shape)
440
+ {
441
+ dims.nbDims = shape.nbDims ;
442
+ for (int idx = 0 ; idx < shape.nbDims ; ++idx) {
443
+ dims.d [idx] = shape.v [idx];
444
+ }
445
+ }
446
+
434
447
#endif
0 commit comments