
Commit

chore: changed to budget range in [0, streamable budget]
keehyuna committed Sep 23, 2024
1 parent 493b0a6 commit 892517d
Showing 7 changed files with 57 additions and 117 deletions.
9 changes: 0 additions & 9 deletions core/runtime/TRTEngine.cpp
@@ -92,11 +92,6 @@ TRTEngine::TRTEngine(
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

if (get_streamable_weights_size() > 0) {
// Scratch memory size may change based on the current weight streaming budget
// Required memory for full streaming is used to minimum weight budget
cuda_engine->setWeightStreamingBudgetV2(0);
min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize();

int64_t budget_bytes = get_weight_streaming_automatic_budget();
LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes);
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
@@ -297,10 +292,6 @@ int64_t TRTEngine::get_streamable_weights_size() {
return cuda_engine->getStreamableWeightsSize();
}

int64_t TRTEngine::get_min_required_device_budget() {
return min_required_device_budget;
}

int64_t TRTEngine::get_weight_streaming_automatic_budget() {
return cuda_engine->getWeightStreamingAutomaticBudget();
}
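For context, here is a minimal sketch of the same initialization flow through the TensorRT Python API, assuming a freshly deserialized engine that was built with weight streaming enabled. The attribute and method names mirror those used by the Python runtime module later in this commit; the helper itself is an illustration, not the runtime's exact code.

import tensorrt as trt

def apply_default_streaming_budget(cuda_engine: trt.ICudaEngine) -> None:
    # Engines built without weight streaming report zero streamable bytes.
    if cuda_engine.streamable_weights_size > 0:
        # Ask TensorRT for a budget derived from currently free device memory
        # and apply it as the default for execution contexts created later.
        budget_bytes = cuda_engine.get_weight_streaming_automatic_budget()
        cuda_engine.weight_streaming_budget_v2 = budget_bytes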
2 changes: 0 additions & 2 deletions core/runtime/TRTEngine.h
@@ -74,7 +74,6 @@ struct TRTEngine : torch::CustomClassHolder {
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_weights_size();
int64_t get_min_required_device_budget();
int64_t get_weight_streaming_automatic_budget();
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';
@@ -105,7 +104,6 @@ struct TRTEngine : torch::CustomClassHolder {
std::string cuda_graph_debug_path;
std::mutex mu;
std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
int64_t min_required_device_budget;
};

} // namespace runtime
1 change: 0 additions & 1 deletion core/runtime/register_jit_hooks.cpp
@@ -91,7 +91,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_weights_size", &TRTEngine::get_streamable_weights_size)
.def_property("min_required_device_budget", &TRTEngine::get_min_required_device_budget)
.def_property("weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
17 changes: 2 additions & 15 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -107,20 +107,13 @@ def __init__(
self.engine = None
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()
self.min_required_device_budget = 0

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def get_streamable_weights_size(self) -> Any:
return self.engine.streamable_weights_size

def get_min_required_device_budget(self) -> Any:
return self.min_required_device_budget

def get_weight_streaming_budget(self) -> Any:
return self.engine.weight_streaming_budget_v2

def get_automatic_weight_streaming_budget(self) -> Any:
return self.engine.get_weight_streaming_automatic_budget()

@@ -137,21 +130,15 @@ def _set_device_memory_budget(self, budget_bytes: int) -> int:
if budget_bytes < 0:
budget_bytes = self.get_streamable_weights_size()
self.engine.weight_streaming_budget_v2 = budget_bytes
if self.get_weight_streaming_budget() != budget_bytes:
if self.engine.weight_streaming_budget_v2 != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
budget_bytes = self.get_weight_streaming_budget()
budget_bytes = self.engine.weight_streaming_budget_v2
if self.engine.streamable_weights_size == budget_bytes:
logger.warning("Weight streaming is disabled")

return budget_bytes

def set_default_streaming_budget(self) -> int:
# Scratch memory size may change based on the current weight streaming budget
# Required memory for full streaming is used to minimum weight budget
self.engine.weight_streaming_budget_v2 = 0
self.min_required_device_budget = (
self.engine.weight_streaming_scratch_memory_size
)
budget_bytes = self.get_automatic_weight_streaming_budget()
# Set automatic weight streaming budget as default when context is created
return self._set_device_memory_budget(budget_bytes)
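To make the new [0, streamable size] range concrete, here is a rough usage sketch against a runtime module exposing the methods above. The trt_mod instance is hypothetical, and the behaviour described in the comments follows the _set_device_memory_budget logic in this diff.

streamable = trt_mod.get_streamable_weights_size()

# A budget of 0 keeps no weights resident on the device, i.e. full streaming.
trt_mod.set_device_memory_budget(0)

# A negative request maps to the full streamable size, which effectively
# disables weight streaming (the module logs a warning in that case).
assert trt_mod.set_device_memory_budget(-1) == streamable

# The default applied when the context is created is the automatic budget,
# which TensorRT derives from the currently available device memory.
trt_mod.set_device_memory_budget(trt_mod.get_automatic_weight_streaming_budget())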
12 changes: 3 additions & 9 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -172,12 +172,6 @@ def _pack_engine_info(self) -> List[str | bytes]:
def get_streamable_weights_size(self) -> Any:
return self.engine.streamable_weights_size

def get_min_required_device_budget(self) -> Any:
return self.engine.min_required_device_budget

def get_weight_streaming_budget(self) -> Any:
return self.engine.device_memory_budget

def get_automatic_weight_streaming_budget(self) -> Any:
return self.engine.weight_streaming_automatic_budget

@@ -186,10 +180,10 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
if budget_bytes < 0:
budget_bytes = self.get_streamable_weights_size()
self.engine.device_memory_budget = budget_bytes
if self.get_weight_streaming_budget() != budget_bytes:
if self.engine.device_memory_budget != budget_bytes:
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
budget_bytes = self.get_weight_streaming_budget()
if self.get_min_required_device_budget() == budget_bytes:
budget_bytes = self.engine.device_memory_budget
if self.get_streamable_weights_size() == budget_bytes:
logger.warning("Weight streaming is disabled")

return budget_bytes
89 changes: 30 additions & 59 deletions py/torch_tensorrt/runtime/_weight_streaming.py
@@ -14,90 +14,61 @@ class _WeightStreamingContextManager(object):

def __init__(self, module: torch.fx.GraphModule) -> None:
rt_mods = []
torch_budget = 0
trt_min_budget = 0
trt_max_budget = 0
current_trt_budget = 0
for name, rt_mod in module.named_children():
if "_run_on_acc" in name and isinstance(
rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
):
trt_min_budget += rt_mod.get_min_required_device_budget()
trt_max_budget += rt_mod.get_streamable_weights_size()
current_trt_budget += rt_mod.get_weight_streaming_budget()
rt_mods.append((name, rt_mod))
else:
torch_budget += sum(
[p.numel() * p.element_size() for p in rt_mod.parameters()]
)

trt_max_budget += trt_min_budget
self.torch_budget = torch_budget
self.streamable_budget = [
mod.get_streamable_weights_size() for _, mod in rt_mods
]
self.rt_mods = rt_mods
device_budget = torch_budget + trt_min_budget + current_trt_budget
super().__setattr__("device_budget", device_budget)
super().__setattr__("trt_max_budget", trt_max_budget)
total_device_budget = sum(self.streamable_budget)
# Device_budget is initialized with total device budget
super().__setattr__("device_budget", total_device_budget)
super().__setattr__("total_device_budget", total_device_budget)

def get_automatic_weight_streaming_budget(self) -> int:
ws_budget_bytes = self.torch_budget
ws_budget_bytes = 0
for _, rt_mod in self.rt_mods:
ws_budget_bytes += rt_mod.get_automatic_weight_streaming_budget()
ws_budget_bytes += rt_mod.get_min_required_device_budget()
return ws_budget_bytes

def get_required_device_budgets(self) -> tuple[int, int]:
min_budget = self.torch_budget
max_budget = self.torch_budget + self.trt_max_budget
for _, rt_mod in self.rt_mods:
min_budget += rt_mod.get_min_required_device_budget()
return min_budget, max_budget

def __enter__(self) -> "_WeightStreamingContextManager":
return self

def __exit__(self, *args: Any) -> None:
max_budget = self.torch_budget + self.trt_max_budget
if self.trt_max_budget > 0:
if self.total_device_budget > 0:
logger.debug(
f"Disable weight streaming by applying max budget size {max_budget}"
f"Disable weight streaming by applying max budget size {self.total_device_budget}"
)
self.device_budget = max_budget
self.device_budget = self.total_device_budget

def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
ws_budget_bytes = self.torch_budget
trt_engine_budget = requested_budget - self.torch_budget
if self.trt_max_budget == 0:
ws_budget_bytes = 0
total_bytes = self.total_device_budget
if total_bytes == 0:
raise RuntimeError(
"Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?"
)
elif trt_engine_budget <= 0:
raise RuntimeError(
f"Requested budget {requested_budget} is less than mininum torch budget: {self.torch_budget}"
elif total_bytes < requested_budget:
logger.error(
f"Requested budget is greater than streamable bytes: {total_bytes}. requested budget: {requested_budget}"
)
else:
# Normalized size is applied for multiple trt runtime module.
# e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively.
# Then 25B and 75B are applied for each module.
for mod_name, rt_mod in self.rt_mods:
max_budget = (
rt_mod.get_min_required_device_budget()
+ rt_mod.get_streamable_weights_size()
)
normalized_size = (
int(max_budget / self.trt_max_budget * trt_engine_budget)
- rt_mod.get_min_required_device_budget()
)

if normalized_size < 0:
raise RuntimeError(
f"Requested trt budget {trt_engine_budget} is less than mininum trt budget of submodule {mod_name} size={rt_mod.get_min_required_device_budget()}"
)
requested_budget = total_bytes
elif requested_budget < 0:
raise RuntimeError("Requested budget cannot be negative")
# Normalized size is applied for multiple runtime module.
# e.g. 100B budget is applied to two modules and they have 1000B and 3000B streamable size respectively.
# Then 25B and 75B are applied for each module.
normalized_size = [
int(streamable_bytes / total_bytes * requested_budget)
for streamable_bytes in self.streamable_budget
]
for i, (name, rt_mod) in enumerate(self.rt_mods):
ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")

ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size)
ws_budget_bytes += rt_mod.get_min_required_device_budget()
logger.debug(
f"Set weight streaming size {normalized_size} for {mod_name}"
)
return ws_budget_bytes

def __setattr__(self, name: str, value: Any) -> None:
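The comment above describes how one requested budget is split across multiple TRT submodules in proportion to each module's streamable weight size. A small self-contained sketch of that arithmetic, using the numbers from the comment (the split_budget helper is hypothetical):

def split_budget(streamable_sizes: list[int], requested_budget: int) -> list[int]:
    # Each submodule receives a share proportional to its streamable bytes.
    total = sum(streamable_sizes)
    return [int(size / total * requested_budget) for size in streamable_sizes]

# Two modules with 1000 B and 3000 B of streamable weights and a 100 B request
# receive 25 B and 75 B respectively, matching the example in the comment.
print(split_budget([1000, 3000], 100))  # [25, 75]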
44 changes: 22 additions & 22 deletions tests/py/dynamo/runtime/test_004_weight_streaming.py
@@ -89,32 +89,33 @@ def test_weight_streaming_manual(self, _, use_python_runtime):
)
# Weight streaming budget is applied manually.
with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
streamable_budget = max_budget - min_budget
streamable_budget = weight_streaming_ctx.device_budget

requested_budget = min_budget + int(streamable_budget * 0.7)
requested_budget = int(streamable_budget * 0.7)
weight_streaming_ctx.device_budget = requested_budget
assert weight_streaming_ctx.device_budget == requested_budget

optimized_model(*input)

# Full streaming by applying min budget
weight_streaming_ctx.device_budget = min_budget
assert weight_streaming_ctx.device_budget == min_budget
# Full streaming by applying 0 budget
weight_streaming_ctx.device_budget = 0
assert weight_streaming_ctx.device_budget == 0

# Automatic weight streaming size
val = weight_streaming_ctx.get_automatic_weight_streaming_budget()
weight_streaming_ctx.device_budget = val
assert weight_streaming_ctx.device_budget > 0
requested_budget = (
weight_streaming_ctx.get_automatic_weight_streaming_budget()
)
weight_streaming_ctx.device_budget = requested_budget
assert weight_streaming_ctx.device_budget == requested_budget

requested_budget = min_budget + int(streamable_budget * 0.5)
requested_budget = int(streamable_budget * 0.5)
weight_streaming_ctx.device_budget = requested_budget
assert weight_streaming_ctx.device_budget == requested_budget

out = optimized_model(*input)

# Weight streaming is disabled after the exit from weight streaming context
assert weight_streaming_ctx.device_budget == max_budget
assert weight_streaming_ctx.device_budget == streamable_budget

ref = model(*input)
torch.testing.assert_close(
@@ -156,19 +157,19 @@ def test_weight_streaming_invalid_usage(self, _, use_python_runtime, multi_rt):

# Setting weight streaming context to unsupported module
with torchtrt.runtime.weight_streaming(model) as weight_streaming_ctx:
min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
assert min_budget == max_budget
streamable_budget = weight_streaming_ctx.device_budget
assert streamable_budget == 0

with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
streamable_budget = weight_streaming_ctx.device_budget

# Values is larger than max budget disables weight streaming
weight_streaming_ctx.device_budget = max_budget + 1
assert weight_streaming_ctx.device_budget == max_budget
weight_streaming_ctx.device_budget = streamable_budget + 1
assert weight_streaming_ctx.device_budget == streamable_budget

try:
# Runtime error if requested budget is less than mininum budget
weight_streaming_ctx.device_budget = min_budget - 1
# Runtime error if requested budget is negative
weight_streaming_ctx.device_budget = -1
assert False
except RuntimeError:
assert True
@@ -201,18 +202,17 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime):
)

with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
streamable_budget = max_budget - min_budget
streamable_budget = weight_streaming_ctx.device_budget
for pct in [0.05, 0.2, 0.4, 0.8, 1.0]:
requested_budget = min_budget + int(streamable_budget * pct)
requested_budget = int(streamable_budget * pct)
weight_streaming_ctx.device_budget = requested_budget

# Budget distribution to multiple submodule may result in integer differences of at most 1
assert abs(weight_streaming_ctx.device_budget - requested_budget) <= 1
out = optimized_model(*input)

# Weight streaming is disabled after the exit from weight streaming context
assert weight_streaming_ctx.device_budget == max_budget
assert weight_streaming_ctx.device_budget == streamable_budget

ref = model(*input)
torch.testing.assert_close(
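A condensed usage sketch of the context manager these tests exercise. It assumes optimized_model was produced by torch_tensorrt compilation with enable_weight_streaming=True and that inputs matches the model's signature; with streaming disabled at compile time the streamable budget is zero and setting it raises a RuntimeError.

import torch_tensorrt as torchtrt

with torchtrt.runtime.weight_streaming(optimized_model) as ctx:
    # The budget starts at the total streamable size, i.e. streaming disabled.
    total = ctx.device_budget

    # Keep roughly half of the streamable weights resident on the device.
    ctx.device_budget = int(total * 0.5)
    out = optimized_model(*inputs)

# On exit the budget is restored to the total size, disabling streaming again.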
