Skip to content

Commit

Permalink
chore: unify naming to use *device_memory_budget
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Oct 2, 2024
1 parent 3e58bc8 commit 7c9cc49
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 34 deletions.
14 changes: 8 additions & 6 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ TRTEngine::TRTEngine(
cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

if (get_streamable_weights_size() > 0) {
int64_t budget_bytes = get_weight_streaming_automatic_budget();
LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes);
if (get_streamable_device_memory_budget() > 0) {
int64_t budget_bytes = get_automatic_device_memory_budget();
LOG_DEBUG("Weight streaming budget set to " << budget_bytes << "B");
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
}

Expand Down Expand Up @@ -280,19 +280,21 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
}
bool result = cuda_engine->setWeightStreamingBudgetV2(budget);
exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context");
TORCHTRT_CHECK(
(exec_ctx.get() != nullptr),
"Unable to recreate TensorRT execution context after setting new device memory budget");
if (profile_execution) {
enable_profiling();
}
return result;
}

// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
int64_t TRTEngine::get_streamable_weights_size() {
int64_t TRTEngine::get_streamable_device_memory_budget() {
return cuda_engine->getStreamableWeightsSize();
}

int64_t TRTEngine::get_weight_streaming_automatic_budget() {
int64_t TRTEngine::get_automatic_device_memory_budget() {
return cuda_engine->getWeightStreamingAutomaticBudget();
}

Expand Down
4 changes: 2 additions & 2 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ struct TRTEngine : torch::CustomClassHolder {
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_weights_size();
int64_t get_weight_streaming_automatic_budget();
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

Expand Down
4 changes: 2 additions & 2 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_weights_size", &TRTEngine::get_streamable_weights_size)
.def_property("weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
Expand Down
18 changes: 9 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def matrix_multiply(
input, other = broadcast(
ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
)

promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(input.dtype).to(torch.dtype),
_enums.dtype._from(other.dtype).to(torch.dtype),
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(input.dtype).to(torch.dtype),
_enums.dtype._from(other.dtype).to(torch.dtype),
)
)
)
trt_promoted_type = promoted_type.to(trt.DataType)
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
trt_promoted_type = promoted_type.to(trt.DataType)
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")

layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
set_layer_name(layer, target, name, source_ir)
Expand Down
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def __init__(
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def get_streamable_weights_size(self) -> Any:
def get_streamable_device_memory_budget(self) -> Any:
return self.engine.streamable_weights_size

def get_automatic_weight_streaming_budget(self) -> Any:
def get_automatic_device_memory_budget(self) -> Any:
return self.engine.get_weight_streaming_automatic_budget()

def get_device_memory_budget(self) -> Any:
Expand All @@ -131,19 +131,20 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
def _set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.get_streamable_weights_size()
budget_bytes = self.get_streamable_device_memory_budget()
self.engine.weight_streaming_budget_v2 = budget_bytes
if self.engine.weight_streaming_budget_v2 != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
budget_bytes = self.engine.weight_streaming_budget_v2
if self.engine.streamable_weights_size == budget_bytes:
if self.get_streamable_device_memory_budget() == budget_bytes:
logger.warning("Weight streaming is disabled")

return budget_bytes

def set_default_streaming_budget(self) -> int:
budget_bytes = self.get_automatic_weight_streaming_budget()
def set_default_device_memory_budget(self) -> int:
budget_bytes = self.get_automatic_device_memory_budget()
# Set automatic weight streaming budget as default when context is created
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
return self._set_device_memory_budget(budget_bytes)

def setup_engine(self) -> None:
Expand All @@ -155,7 +156,7 @@ def setup_engine(self) -> None:
runtime = trt.Runtime(TRT_LOGGER)
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
if self.settings.enable_weight_streaming:
self.set_default_streaming_budget()
self.set_default_device_memory_budget()
self.context = self.engine.create_execution_context()

assert self.engine.num_io_tensors == (
Expand Down
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,24 @@ def _pack_engine_info(self) -> List[str | bytes]:

return engine_info

def get_streamable_weights_size(self) -> Any:
return self.engine.streamable_weights_size
def get_streamable_device_memory_budget(self) -> Any:
return self.engine.streamable_device_memory_budget

def get_automatic_weight_streaming_budget(self) -> Any:
return self.engine.weight_streaming_automatic_budget
def get_automatic_device_memory_budget(self) -> Any:
return self.engine.automatic_device_memory_budget

def get_device_memory_budget(self) -> Any:
return self.engine.device_memory_budget

def set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.get_streamable_weights_size()
budget_bytes = self.get_streamable_device_memory_budget()
self.engine.device_memory_budget = budget_bytes
if self.engine.device_memory_budget != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
budget_bytes = self.engine.device_memory_budget
if self.get_streamable_weights_size() == budget_bytes:
if self.get_streamable_device_memory_budget() == budget_bytes:
logger.warning("Weight streaming is disabled")

return budget_bytes
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/runtime/_weight_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, module: torch.fx.GraphModule) -> None:
rt_mods.append((name, rt_mod))
self.current_device_budget += rt_mod.get_device_memory_budget()
self.streamable_budget = [
mod.get_streamable_weights_size() for _, mod in rt_mods
mod.get_streamable_device_memory_budget() for _, mod in rt_mods
]
self.rt_mods = rt_mods
total_device_budget = sum(self.streamable_budget)
Expand All @@ -32,7 +32,7 @@ def __init__(self, module: torch.fx.GraphModule) -> None:
def get_automatic_weight_streaming_budget(self) -> int:
ws_budget_bytes = 0
for _, rt_mod in self.rt_mods:
ws_budget_bytes += rt_mod.get_automatic_weight_streaming_budget()
ws_budget_bytes += rt_mod.get_automatic_device_memory_budget()
return ws_budget_bytes

def __enter__(self) -> "_WeightStreamingContextManager":
Expand Down

0 comments on commit 7c9cc49

Please sign in to comment.