@@ -94,12 +94,12 @@ TRTEngine::TRTEngine(
94
94
if (get_streamable_weights_size () > 0 ) {
95
95
// Scratch memory size may change based on the current weight streaming budget
96
96
// Required memory for full streaming is used to minimum weight budget
97
- set_device_memory_budget (0 );
97
+ cuda_engine-> setWeightStreamingBudgetV2 (0 );
98
98
min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize ();
99
99
100
100
int64_t budget_bytes = get_weight_streaming_automatic_budget ();
101
101
LOG_INFO (" Set automatic weight streaming budget bytes " << budget_bytes);
102
- set_device_memory_budget (budget_bytes);
102
+ cuda_engine-> setWeightStreamingBudgetV2 (budget_bytes);
103
103
}
104
104
105
105
exec_ctx = make_trt (cuda_engine->createExecutionContext ());
@@ -276,7 +276,20 @@ int64_t TRTEngine::get_device_memory_budget() {
276
276
}
277
277
278
278
bool TRTEngine::set_device_memory_budget (int64_t budget) {
279
- return cuda_engine->setWeightStreamingBudgetV2 (budget);
279
+ // Recreating the context because weight streaming budget cannot be modified while there are active context.
280
+ if (exec_ctx.get () != nullptr ) {
281
+ exec_ctx.reset ();
282
+ }
283
+ if (profile_execution) {
284
+ trt_engine_profiler.reset ();
285
+ }
286
+ bool result = cuda_engine->setWeightStreamingBudgetV2 (budget);
287
+ exec_ctx = make_trt (cuda_engine->createExecutionContext ());
288
+ TORCHTRT_CHECK ((exec_ctx.get () != nullptr ), " Unable to recreate TensorRT execution context" );
289
+ if (profile_execution) {
290
+ enable_profiling ();
291
+ }
292
+ return result;
280
293
}
281
294
282
295
// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
@@ -292,26 +305,6 @@ int64_t TRTEngine::get_weight_streaming_automatic_budget() {
292
305
return cuda_engine->getWeightStreamingAutomaticBudget ();
293
306
}
294
307
295
- void TRTEngine::init_context () {
296
- if (exec_ctx.get () == nullptr ) {
297
- exec_ctx = make_trt (cuda_engine->createExecutionContext ());
298
- TORCHTRT_CHECK ((exec_ctx.get () != nullptr ), " Unable to recreate TensorRT execution context" );
299
- if (profile_execution) {
300
- enable_profiling ();
301
- }
302
- }
303
- }
304
-
305
- void TRTEngine::reset_context () {
306
- if (exec_ctx.get () != nullptr ) {
307
- exec_ctx.reset ();
308
- exec_ctx = nullptr ;
309
- }
310
- if (profile_execution) {
311
- trt_engine_profiler.reset ();
312
- }
313
- }
314
-
315
308
std::string TRTEngine::to_str () const {
316
309
// clang-format off
317
310
std::stringstream ss;
0 commit comments