Skip to content

Commit

Permalink
chore: cpp runtime update for min + streamable budget
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Sep 3, 2024
1 parent f999bbf commit 99f76f7
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 155 deletions.
13 changes: 11 additions & 2 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ TRTEngine::TRTEngine(
cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

if (get_min_required_device_budget() > 0) {
if (get_streamable_weights_size() > 0) {
// Scratch memory size may change based on the current weight streaming budget
// The memory required for full streaming is used as the minimum weight budget
set_device_memory_budget(0);
min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize();

int64_t budget_bytes = get_weight_streaming_automatic_budget();
LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes);
set_device_memory_budget(budget_bytes);
Expand Down Expand Up @@ -275,10 +280,14 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
}

// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
int64_t TRTEngine::get_min_required_device_budget() {
int64_t TRTEngine::get_streamable_weights_size() {
return cuda_engine->getStreamableWeightsSize();
}

int64_t TRTEngine::get_min_required_device_budget() {
return min_required_device_budget;
}

// Queries TensorRT for an automatically determined weight streaming budget.
int64_t TRTEngine::get_weight_streaming_automatic_budget() {
  auto* engine_ptr = cuda_engine.get();
  return engine_ptr->getWeightStreamingAutomaticBudget();
}
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct TRTEngine : torch::CustomClassHolder {
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_weights_size();
int64_t get_min_required_device_budget();
int64_t get_weight_streaming_automatic_budget();
void init_context();
Expand Down Expand Up @@ -106,6 +107,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::string cuda_graph_debug_path;
std::mutex mu;
std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
int64_t min_required_device_budget;
};

} // namespace runtime
Expand Down
3 changes: 2 additions & 1 deletion core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_weights_size", &TRTEngine::get_streamable_weights_size)
.def_property("min_required_device_budget", &TRTEngine::get_min_required_device_budget)
.def("get_weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget)
.def_property("weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget)
.def("init_context", &TRTEngine::init_context)
.def("reset_context", &TRTEngine::reset_context)
.def_pickle(
Expand Down
32 changes: 12 additions & 20 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import logging
from contextlib import nullcontext
from functools import wraps
from tempfile import tempdir
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
Expand All @@ -13,6 +12,9 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform, dtype
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
recreate_context_decorator,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER
from torch_tensorrt.runtime._utils import (
Expand All @@ -24,22 +26,6 @@
logger = logging.getLogger(__name__)


def recreate_context_decorator(method: Callable[..., Any]) -> Callable[..., Any]:
    """
    A decorator that destroys a context before a method execution and
    creates it after the method execution within the same class instance.

    The context is re-initialized in a ``finally`` block so the instance
    is left with a valid context even if the wrapped method raises.
    """

    @wraps(method)
    def wrapper(self: object, *args: Any, **kwargs: Any) -> Any:
        self.reset_context()
        try:
            return method(self, *args, **kwargs)
        finally:
            # Always restore the context, including on the error path.
            self.init_context()

    return wrapper


class PythonTorchTensorRTModule(Module): # type: ignore[misc]
"""PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
Expand Down Expand Up @@ -142,17 +128,23 @@ def reset_context(self) -> None:
def get_streamable_weights_size(self) -> Any:
    """Size in bytes of the engine weights eligible for streaming."""
    engine = self.engine
    return engine.streamable_weights_size

def get_min_required_device_budget(self) -> Any:
    """Minimum device memory budget recorded on this module instance."""
    min_budget = self.min_required_device_budget
    return min_budget

def get_weight_streaming_budget(self) -> Any:
    """Currently applied weight streaming budget (v2 attribute), in bytes."""
    engine = self.engine
    return engine.weight_streaming_budget_v2

def get_automatic_weight_streaming_budget(self) -> Any:
    """Budget size suggested automatically by the underlying engine."""
    engine = self.engine
    return engine.get_weight_streaming_automatic_budget()

@recreate_context_decorator
def set_device_memory_budget(self, budget_bytes: int) -> int:
    """Apply ``budget_bytes`` as the device memory budget.

    The decorator recreates the TRT execution context around the update.
    Returns the budget actually applied.
    """
    applied_budget = self._set_device_memory_budget(budget_bytes)
    return applied_budget

def _set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.engine.streamable_weights_size
budget_bytes = self.get_streamable_weights_size()
self.engine.weight_streaming_budget_v2 = budget_bytes
if self.get_weight_streaming_budget() != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
Expand All @@ -169,7 +161,7 @@ def set_default_streaming_budget(self) -> int:
self.min_required_device_budget = (
self.engine.weight_streaming_scratch_memory_size
)
budget_bytes = self.engine.get_weight_streaming_automatic_budget()
budget_bytes = self.get_automatic_weight_streaming_budget()
# Set automatic weight streaming budget as default when context is created
return self._set_device_memory_budget(budget_bytes)

Expand Down
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,34 +192,35 @@ def init_context(self) -> None:
def reset_context(self) -> None:
    """Delegate the context reset to the wrapped engine object."""
    engine = self.engine
    engine.reset_context()

def get_streamable_weights_size(self) -> Any:
    """Return the engine-reported streamable weights size in bytes."""
    streamable_bytes = self.engine.streamable_weights_size
    return streamable_bytes

def get_min_required_device_budget(self) -> Any:
    """Minimum device budget required by the underlying TRT engine."""
    required = self.engine.min_required_device_budget
    return required

def get_weight_streaming_budget(self) -> Any:
    """Currently configured device memory budget of the engine."""
    current = self.engine.device_memory_budget
    return current

def get_automatic_weight_streaming_budget(self) -> Any:
    """Automatic weight streaming budget exposed by the engine."""
    auto_budget = self.engine.weight_streaming_automatic_budget
    return auto_budget

@recreate_context_decorator
def set_device_memory_budget(self, budget_bytes: int) -> int:
    """Set the weight streaming device memory budget.

    Wrapped by ``recreate_context_decorator``: the execution context is
    reset before the budget change and re-initialized afterwards.
    """
    return self._set_device_memory_budget(budget_bytes)

def _set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.get_min_required_device_budget()

budget_bytes = self.get_streamable_weights_size()
self.engine.device_memory_budget = budget_bytes
if self.get_weight_streaming_budget() != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
budget_bytes = self.get_weight_streaming_budget()
if self.engine.min_required_device_budget == budget_bytes:
if self.get_min_required_device_budget() == budget_bytes:
logger.warning("Weight streaming is disabled")

return budget_bytes

def set_automatic_streaming_budget(self) -> int:
    """Query the engine's automatic budget and apply it as the device budget."""
    auto_budget = self.engine.get_weight_streaming_automatic_budget()
    return self._set_device_memory_budget(auto_budget)

def setup_engine(self) -> None:
"""
Setup engine for a module which has deferred engine setup.
Expand Down
111 changes: 61 additions & 50 deletions py/torch_tensorrt/runtime/_weight_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,83 +15,94 @@ class _WeightStreamingContextManager(object):
def __init__(self, module: torch.fx.GraphModule) -> None:
rt_mods = []
torch_budget = 0
trt_budget = 0
trt_min_budget = 0
trt_max_budget = 0
current_trt_budget = 0
for name, rt_mod in module.named_children():
if "_run_on_acc" in name and isinstance(
rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
):
trt_budget += rt_mod.min_required_device_budget
trt_budget += rt_mod.get_streamable_weights_size()
trt_min_budget += rt_mod.get_min_required_device_budget()
trt_max_budget += rt_mod.get_streamable_weights_size()
current_trt_budget += rt_mod.get_weight_streaming_budget()
rt_mods.append((name, rt_mod))
else:
torch_budget += sum(
[p.numel() * p.element_size() for p in rt_mod.parameters()]
)

trt_max_budget += trt_min_budget
self.torch_budget = torch_budget
self.rt_mods = rt_mods
total_device_budget = torch_budget + trt_budget
# device_budget is -1 if there is no trt module
device_budget = -1 if trt_budget == 0 else total_device_budget
device_budget = torch_budget + trt_min_budget + current_trt_budget
super().__setattr__("device_budget", device_budget)
super().__setattr__("total_trt_budget", trt_budget)
super().__setattr__("trt_max_budget", trt_max_budget)

def get_automatic_weight_streaming_budget(self) -> int:
    """Total automatic budget: torch parameter bytes plus, per TRT module,
    its automatic streaming budget and its minimum required budget."""
    trt_bytes = sum(
        mod.get_automatic_weight_streaming_budget()
        + mod.get_min_required_device_budget()
        for _, mod in self.rt_mods
    )
    return self.torch_budget + trt_bytes

def get_required_device_budgets(self) -> tuple[int, int]:
    """Return ``(min_budget, max_budget)`` in bytes.

    min: torch parameter bytes plus every TRT module's minimum required
    budget. max: torch parameter bytes plus the full streamable TRT size.
    """
    min_budget = self.torch_budget
    max_budget = self.torch_budget + self.trt_max_budget
    for _, rt_mod in self.rt_mods:
        min_budget += rt_mod.get_min_required_device_budget()
    return min_budget, max_budget

def __enter__(self) -> "_WeightStreamingContextManager":
    """Enter the weight streaming context; the manager itself is the target."""
    return self

def __exit__(self, *args: Any) -> None:
for name, rt_mod in self.rt_mods:
streamable_budget = rt_mod.get_streamable_weights_size()
rt_mod.set_device_memory_budget(streamable_budget)
max_budget = self.torch_budget + self.trt_max_budget
if self.trt_max_budget > 0:
logger.debug(
f"Disable weight streaming by setting size {streamable_budget} for {name}"
f"Disable weight streaming by applying max budget size {max_budget}"
)
self.device_budget = max_budget

def __setattr__(self, name: str, value: Any) -> None:
if name == "device_budget":
requested_budget = value
trt_engine_budget = requested_budget - self.torch_budget
value = 0
if self.total_trt_budget == 0:
logger.error(
"Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?"
def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
ws_budget_bytes = self.torch_budget
trt_engine_budget = requested_budget - self.torch_budget
if self.trt_max_budget == 0:
raise RuntimeError(
"Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?"
)
elif trt_engine_budget <= 0:
raise RuntimeError(
f"Requested budget {requested_budget} is less than mininum torch budget: {self.torch_budget}"
)
else:
# Normalized size is applied for multiple trt runtime module.
# e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively.
# Then 25B and 75B are applied for each module.
for mod_name, rt_mod in self.rt_mods:
max_budget = (
rt_mod.get_min_required_device_budget()
+ rt_mod.get_streamable_weights_size()
)
value = -1
elif trt_engine_budget <= 0:
logger.error(
f"Requested budget {requested_budget} is less than mininum torch budget: {self.torch_budget}"
normalized_size = (
int(max_budget / self.trt_max_budget * trt_engine_budget)
- rt_mod.get_min_required_device_budget()
)
value = -1
else:
# Normalized size is applied for multiple trt runtime module.
# e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively.
# Then 25B and 75B are applied for each module.
for mod_name, rt_mod in self.rt_mods:
max_budget = (
rt_mod.min_required_device_budget
+ rt_mod.get_streamable_weights_size()
)
normalized_size = (
int(max_budget / self.total_trt_budget * trt_engine_budget)
- rt_mod.min_required_device_budget
)
if normalized_size < 0:
logger.error(
f"Requested trt budget {trt_engine_budget} is less than mininum trt budget: {rt_mod.min_required_device_budget}"
)
value = -1
break
value += rt_mod.set_device_memory_budget(normalized_size)
value += rt_mod.min_required_device_budget
logger.debug(
f"Set weight streaming size {normalized_size} for {mod_name}"

if normalized_size < 0:
raise RuntimeError(
f"Requested trt budget {trt_engine_budget} is less than mininum trt budget of submodule {mod_name} size={rt_mod.get_min_required_device_budget()}"
)

ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size)
ws_budget_bytes += rt_mod.get_min_required_device_budget()
logger.debug(
f"Set weight streaming size {normalized_size} for {mod_name}"
)
return ws_budget_bytes

def __setattr__(self, name: str, value: Any) -> None:
    """Intercept writes to ``device_budget`` and apply them to the TRT
    modules; all other attributes are stored unchanged."""
    if name != "device_budget":
        super().__setattr__(name, value)
        return
    applied = self._set_streamable_weight_bytes(value)
    super().__setattr__(name, applied)


Expand Down
Loading

0 comments on commit 99f76f7

Please sign in to comment.