Skip to content

Commit 493b0a6

Browse files
committed
chore: reset context in weight budget setting
1 parent dde5e3c commit 493b0a6

File tree

9 files changed

+30
-77
lines changed

9 files changed

+30
-77
lines changed

core/runtime/TRTEngine.cpp

+16-23
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ TRTEngine::TRTEngine(
9494
if (get_streamable_weights_size() > 0) {
9595
// Scratch memory size may change based on the current weight streaming budget
9696
// Required memory for full streaming is used as the minimum weight budget
97-
set_device_memory_budget(0);
97+
cuda_engine->setWeightStreamingBudgetV2(0);
9898
min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize();
9999

100100
int64_t budget_bytes = get_weight_streaming_automatic_budget();
101101
LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes);
102-
set_device_memory_budget(budget_bytes);
102+
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
103103
}
104104

105105
exec_ctx = make_trt(cuda_engine->createExecutionContext());
@@ -276,7 +276,20 @@ int64_t TRTEngine::get_device_memory_budget() {
276276
}
277277

278278
bool TRTEngine::set_device_memory_budget(int64_t budget) {
279-
return cuda_engine->setWeightStreamingBudgetV2(budget);
279+
// Recreating the context because the weight streaming budget cannot be modified while there is an active context.
280+
if (exec_ctx.get() != nullptr) {
281+
exec_ctx.reset();
282+
}
283+
if (profile_execution) {
284+
trt_engine_profiler.reset();
285+
}
286+
bool result = cuda_engine->setWeightStreamingBudgetV2(budget);
287+
exec_ctx = make_trt(cuda_engine->createExecutionContext());
288+
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context");
289+
if (profile_execution) {
290+
enable_profiling();
291+
}
292+
return result;
280293
}
281294

282295
// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
@@ -292,26 +305,6 @@ int64_t TRTEngine::get_weight_streaming_automatic_budget() {
292305
return cuda_engine->getWeightStreamingAutomaticBudget();
293306
}
294307

295-
void TRTEngine::init_context() {
296-
if (exec_ctx.get() == nullptr) {
297-
exec_ctx = make_trt(cuda_engine->createExecutionContext());
298-
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context");
299-
if (profile_execution) {
300-
enable_profiling();
301-
}
302-
}
303-
}
304-
305-
void TRTEngine::reset_context() {
306-
if (exec_ctx.get() != nullptr) {
307-
exec_ctx.reset();
308-
exec_ctx = nullptr;
309-
}
310-
if (profile_execution) {
311-
trt_engine_profiler.reset();
312-
}
313-
}
314-
315308
std::string TRTEngine::to_str() const {
316309
// clang-format off
317310
std::stringstream ss;

core/runtime/TRTEngine.h

-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ struct TRTEngine : torch::CustomClassHolder {
7676
int64_t get_streamable_weights_size();
7777
int64_t get_min_required_device_budget();
7878
int64_t get_weight_streaming_automatic_budget();
79-
void init_context();
80-
void reset_context();
8179
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
8280
static const char BINDING_DELIM = '%';
8381

core/runtime/register_jit_hooks.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
9393
.def_property("streamable_weights_size", &TRTEngine::get_streamable_weights_size)
9494
.def_property("min_required_device_budget", &TRTEngine::get_min_required_device_budget)
9595
.def_property("weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget)
96-
.def("init_context", &TRTEngine::init_context)
97-
.def("reset_context", &TRTEngine::reset_context)
9896
.def_pickle(
9997
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
10098
// Serialize TensorRT engine

py/torch_tensorrt/dynamo/_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def compile(
8888
engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
8989
engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
9090
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
91-
enable_weight_streaming: bool = _defaults.WEIGHT_STREAMING,
91+
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
9292
**kwargs: Any,
9393
) -> torch.fx.GraphModule:
9494
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT

py/torch_tensorrt/dynamo/_defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
4141
ENGINE_CACHE_SIZE = 1073741824
4242
CUSTOM_ENGINE_CACHE = None
43-
WEIGHT_STREAMING = False
43+
ENABLE_WEIGHT_STREAMING = False
4444

4545

4646
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
DLA_SRAM_SIZE,
1515
DRYRUN,
1616
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
17+
ENABLE_WEIGHT_STREAMING,
1718
ENABLED_PRECISIONS,
1819
ENGINE_CAPABILITY,
1920
HARDWARE_COMPATIBLE,
@@ -32,7 +33,6 @@
3233
USE_FAST_PARTITIONER,
3334
USE_PYTHON_RUNTIME,
3435
VERSION_COMPATIBLE,
35-
WEIGHT_STREAMING,
3636
WORKSPACE_SIZE,
3737
default_device,
3838
)
@@ -114,4 +114,4 @@ class CompilationSettings:
114114
lazy_engine_init: bool = LAZY_ENGINE_INIT
115115
cache_built_engines: bool = CACHE_BUILT_ENGINES
116116
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
117-
enable_weight_streaming: bool = WEIGHT_STREAMING
117+
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
from torch_tensorrt._Device import Device
1313
from torch_tensorrt._enums import Platform, dtype
1414
from torch_tensorrt.dynamo._settings import CompilationSettings
15-
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
16-
recreate_context_decorator,
17-
)
1815
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
1916
from torch_tensorrt.logging import TRT_LOGGER
2017
from torch_tensorrt.runtime._utils import (
@@ -115,16 +112,6 @@ def __init__(
115112
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
116113
self.setup_engine()
117114

118-
def init_context(self) -> None:
119-
assert self.engine, "Context is used before setting up the engine"
120-
if self.context is None:
121-
self.context = self.engine.create_execution_context()
122-
123-
def reset_context(self) -> None:
124-
if self.context is not None:
125-
del self.context
126-
self.context = None
127-
128115
def get_streamable_weights_size(self) -> Any:
129116
return self.engine.streamable_weights_size
130117

@@ -137,9 +124,13 @@ def get_weight_streaming_budget(self) -> Any:
137124
def get_automatic_weight_streaming_budget(self) -> Any:
138125
return self.engine.get_weight_streaming_automatic_budget()
139126

140-
@recreate_context_decorator
141127
def set_device_memory_budget(self, budget_bytes: int) -> int:
142-
return self._set_device_memory_budget(budget_bytes)
128+
# Recreating the context because the weight streaming budget cannot be modified while there is an active context.
129+
if self.context is not None:
130+
del self.context
131+
budget_bytes = self._set_device_memory_budget(budget_bytes)
132+
self.context = self.engine.create_execution_context()
133+
return budget_bytes
143134

144135
def _set_device_memory_budget(self, budget_bytes: int) -> int:
145136
# Disable weight streaming for invalid budget size

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+1-28
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import copy
55
import logging
66
import pickle
7-
from functools import wraps
8-
from typing import Any, Callable, List, Optional, Tuple, Union
7+
from typing import Any, List, Optional, Tuple, Union
98

109
import torch
1110
from torch_tensorrt._Device import Device
@@ -50,22 +49,6 @@
5049
SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 9
5150

5251

53-
def recreate_context_decorator(method: Callable[..., Any]) -> Callable[..., Any]:
54-
"""
55-
A decorator that destroys a context before a method execution and
56-
creates it after the method execution within the same class instance.
57-
"""
58-
59-
@wraps(method)
60-
def wrapper(self: object, *args: Any, **kwargs: Any) -> Any:
61-
self.reset_context()
62-
result = method(self, *args, **kwargs)
63-
self.init_context()
64-
return result
65-
66-
return wrapper
67-
68-
6952
@for_all_methods(needs_torch_tensorrt_runtime)
7053
class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
7154
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
@@ -186,12 +169,6 @@ def _pack_engine_info(self) -> List[str | bytes]:
186169

187170
return engine_info
188171

189-
def init_context(self) -> None:
190-
self.engine.init_context()
191-
192-
def reset_context(self) -> None:
193-
self.engine.reset_context()
194-
195172
def get_streamable_weights_size(self) -> Any:
196173
return self.engine.streamable_weights_size
197174

@@ -204,11 +181,7 @@ def get_weight_streaming_budget(self) -> Any:
204181
def get_automatic_weight_streaming_budget(self) -> Any:
205182
return self.engine.weight_streaming_automatic_budget
206183

207-
@recreate_context_decorator
208184
def set_device_memory_budget(self, budget_bytes: int) -> int:
209-
return self._set_device_memory_budget(budget_bytes)
210-
211-
def _set_device_memory_budget(self, budget_bytes: int) -> int:
212185
# Disable weight streaming for invalid budget size
213186
if budget_bytes < 0:
214187
budget_bytes = self.get_streamable_weights_size()

tests/py/dynamo/runtime/test_004_weight_streaming.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def test_weight_streaming_default(self, _, use_python_runtime):
4949
use_python_runtime=use_python_runtime,
5050
enable_weight_streaming=True,
5151
)
52-
# Checking default weight streaming budget(automatic) is applied
53-
with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
54-
assert weight_streaming_ctx.device_budget > 0
52+
# Checking that the default weight streaming budget (automatic) is applied when the compiler option was provided
53+
weight_streaming_ctx = torchtrt.runtime.weight_streaming(optimized_model)
54+
assert weight_streaming_ctx.device_budget > 0
5555

5656
ref = model(*input)
5757
out = optimized_model(*input)

0 commit comments

Comments
 (0)