From 892517db56152425d9fbdb3f79dc53048ebf9135 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Mon, 23 Sep 2024 05:34:39 +0000 Subject: [PATCH] chore: changed to budget range in [0, streamable budget] --- core/runtime/TRTEngine.cpp | 9 -- core/runtime/TRTEngine.h | 2 - core/runtime/register_jit_hooks.cpp | 1 - .../runtime/_PythonTorchTensorRTModule.py | 17 +--- .../dynamo/runtime/_TorchTensorRTModule.py | 12 +-- .../runtime/_weight_streaming.py | 89 +++++++------------ .../runtime/test_004_weight_streaming.py | 44 ++++----- 7 files changed, 57 insertions(+), 117 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d8c76a39cc..2facff4d52 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -92,11 +92,6 @@ TRTEngine::TRTEngine( TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); if (get_streamable_weights_size() > 0) { - // Scratch memory size may change based on the current weight streaming budget - // Required memory for full streaming is used to minimum weight budget - cuda_engine->setWeightStreamingBudgetV2(0); - min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize(); - int64_t budget_bytes = get_weight_streaming_automatic_budget(); LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes); cuda_engine->setWeightStreamingBudgetV2(budget_bytes); @@ -297,10 +292,6 @@ int64_t TRTEngine::get_streamable_weights_size() { return cuda_engine->getStreamableWeightsSize(); } -int64_t TRTEngine::get_min_required_device_budget() { - return min_required_device_budget; -} - int64_t TRTEngine::get_weight_streaming_automatic_budget() { return cuda_engine->getWeightStreamingAutomaticBudget(); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index f10623387b..2df48e8b6b 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -74,7 +74,6 @@ struct TRTEngine : torch::CustomClassHolder { int64_t 
get_device_memory_budget(); bool set_device_memory_budget(int64_t budget); int64_t get_streamable_weights_size(); - int64_t get_min_required_device_budget(); int64_t get_weight_streaming_automatic_budget(); friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); static const char BINDING_DELIM = '%'; @@ -105,7 +104,6 @@ struct TRTEngine : torch::CustomClassHolder { std::string cuda_graph_debug_path; std::mutex mu; std::unique_ptr trt_engine_profiler; - int64_t min_required_device_budget; }; } // namespace runtime diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 9b7007b95a..4dfc4de552 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -91,7 +91,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = &TRTEngine::get_device_memory_budget, &TRTEngine::set_device_memory_budget) .def_property("streamable_weights_size", &TRTEngine::get_streamable_weights_size) - .def_property("min_required_device_budget", &TRTEngine::get_min_required_device_budget) .def_property("weight_streaming_automatic_budget", &TRTEngine::get_weight_streaming_automatic_budget) .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3fe6429686..0a43143991 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -107,7 +107,6 @@ def __init__( self.engine = None self.weight_name_map = weight_name_map self.target_platform = Platform.current_platform() - self.min_required_device_budget = 0 if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -115,12 +114,6 @@ def __init__( def get_streamable_weights_size(self) -> Any: return self.engine.streamable_weights_size - def get_min_required_device_budget(self) -> 
Any: - return self.min_required_device_budget - - def get_weight_streaming_budget(self) -> Any: - return self.engine.weight_streaming_budget_v2 - def get_automatic_weight_streaming_budget(self) -> Any: return self.engine.get_weight_streaming_automatic_budget() @@ -137,21 +130,15 @@ def _set_device_memory_budget(self, budget_bytes: int) -> int: if budget_bytes < 0: budget_bytes = self.get_streamable_weights_size() self.engine.weight_streaming_budget_v2 = budget_bytes - if self.get_weight_streaming_budget() != budget_bytes: + if self.engine.weight_streaming_budget_v2 != budget_bytes: logger.error(f"Failed to set weight streaming budget to {budget_bytes}") - budget_bytes = self.get_weight_streaming_budget() + budget_bytes = self.engine.weight_streaming_budget_v2 if self.engine.streamable_weights_size == budget_bytes: logger.warning("Weight streaming is disabled") return budget_bytes def set_default_streaming_budget(self) -> int: - # Scratch memory size may change based on the current weight streaming budget - # Required memory for full streaming is used to minimum weight budget - self.engine.weight_streaming_budget_v2 = 0 - self.min_required_device_budget = ( - self.engine.weight_streaming_scratch_memory_size - ) budget_bytes = self.get_automatic_weight_streaming_budget() # Set automatic weight streaming budget as default when context is created return self._set_device_memory_budget(budget_bytes) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 5ad4707a0a..c507989bd8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -172,12 +172,6 @@ def _pack_engine_info(self) -> List[str | bytes]: def get_streamable_weights_size(self) -> Any: return self.engine.streamable_weights_size - def get_min_required_device_budget(self) -> Any: - return self.engine.min_required_device_budget - - def 
get_weight_streaming_budget(self) -> Any: - return self.engine.device_memory_budget - def get_automatic_weight_streaming_budget(self) -> Any: return self.engine.weight_streaming_automatic_budget @@ -186,10 +180,10 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: if budget_bytes < 0: budget_bytes = self.get_streamable_weights_size() self.engine.device_memory_budget = budget_bytes - if self.get_weight_streaming_budget() != budget_bytes: + if self.engine.device_memory_budget != budget_bytes: logger.error(f"Failed to set weight streaming budget to {budget_bytes}") - budget_bytes = self.get_weight_streaming_budget() - if self.get_min_required_device_budget() == budget_bytes: + budget_bytes = self.engine.device_memory_budget + if self.get_streamable_weights_size() == budget_bytes: logger.warning("Weight streaming is disabled") return budget_bytes diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 304815b838..823c2819d8 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -14,90 +14,61 @@ class _WeightStreamingContextManager(object): def __init__(self, module: torch.fx.GraphModule) -> None: rt_mods = [] - torch_budget = 0 - trt_min_budget = 0 - trt_max_budget = 0 - current_trt_budget = 0 for name, rt_mod in module.named_children(): if "_run_on_acc" in name and isinstance( rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule) ): - trt_min_budget += rt_mod.get_min_required_device_budget() - trt_max_budget += rt_mod.get_streamable_weights_size() - current_trt_budget += rt_mod.get_weight_streaming_budget() rt_mods.append((name, rt_mod)) - else: - torch_budget += sum( - [p.numel() * p.element_size() for p in rt_mod.parameters()] - ) - - trt_max_budget += trt_min_budget - self.torch_budget = torch_budget + self.streamable_budget = [ + mod.get_streamable_weights_size() for _, mod in rt_mods + ] self.rt_mods = rt_mods - device_budget = 
torch_budget + trt_min_budget + current_trt_budget - super().__setattr__("device_budget", device_budget) - super().__setattr__("trt_max_budget", trt_max_budget) + total_device_budget = sum(self.streamable_budget) + # Device_budget is initialized with total device budget + super().__setattr__("device_budget", total_device_budget) + super().__setattr__("total_device_budget", total_device_budget) def get_automatic_weight_streaming_budget(self) -> int: - ws_budget_bytes = self.torch_budget + ws_budget_bytes = 0 for _, rt_mod in self.rt_mods: ws_budget_bytes += rt_mod.get_automatic_weight_streaming_budget() - ws_budget_bytes += rt_mod.get_min_required_device_budget() return ws_budget_bytes - def get_required_device_budgets(self) -> tuple[int, int]: - min_budget = self.torch_budget - max_budget = self.torch_budget + self.trt_max_budget - for _, rt_mod in self.rt_mods: - min_budget += rt_mod.get_min_required_device_budget() - return min_budget, max_budget - def __enter__(self) -> "_WeightStreamingContextManager": return self def __exit__(self, *args: Any) -> None: - max_budget = self.torch_budget + self.trt_max_budget - if self.trt_max_budget > 0: + if self.total_device_budget > 0: logger.debug( - f"Disable weight streaming by applying max budget size {max_budget}" + f"Disable weight streaming by applying max budget size {self.total_device_budget}" ) - self.device_budget = max_budget + self.device_budget = self.total_device_budget def _set_streamable_weight_bytes(self, requested_budget: int) -> int: - ws_budget_bytes = self.torch_budget - trt_engine_budget = requested_budget - self.torch_budget - if self.trt_max_budget == 0: + ws_budget_bytes = 0 + total_bytes = self.total_device_budget + if total_bytes == 0: raise RuntimeError( "Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?" 
) - elif trt_engine_budget <= 0: - raise RuntimeError( - f"Requested budget {requested_budget} is less than mininum torch budget: {self.torch_budget}" + elif total_bytes < requested_budget: + logger.error( + f"Requested budget is greater than streamable bytes: {total_bytes}. requested budget: {requested_budget}" ) - else: - # Normalized size is applied for multiple trt runtime module. - # e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively. - # Then 25B and 75B are applied for each module. - for mod_name, rt_mod in self.rt_mods: - max_budget = ( - rt_mod.get_min_required_device_budget() - + rt_mod.get_streamable_weights_size() - ) - normalized_size = ( - int(max_budget / self.trt_max_budget * trt_engine_budget) - - rt_mod.get_min_required_device_budget() - ) - - if normalized_size < 0: - raise RuntimeError( - f"Requested trt budget {trt_engine_budget} is less than mininum trt budget of submodule {mod_name} size={rt_mod.get_min_required_device_budget()}" - ) + requested_budget = total_bytes + elif requested_budget < 0: + raise RuntimeError("Requested budget cannot be negative") + # Normalized size is applied for multiple runtime modules. + # e.g. 100B budget is applied to two modules and they have 1000B and 3000B streamable size respectively. + # Then 25B and 75B are applied for each module. 
+ normalized_size = [ + int(streamable_bytes / total_bytes * requested_budget) + for streamable_bytes in self.streamable_budget + ] + for i, (name, rt_mod) in enumerate(self.rt_mods): + ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) + logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}") - ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size) - ws_budget_bytes += rt_mod.get_min_required_device_budget() - logger.debug( - f"Set weight streaming size {normalized_size} for {mod_name}" - ) return ws_budget_bytes def __setattr__(self, name: str, value: Any) -> None: diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index 4eb1c7d989..ae52a71e4d 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -89,32 +89,33 @@ def test_weight_streaming_manual(self, _, use_python_runtime): ) # Weight streaming budget is applied manually. 
with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx: - min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets() - streamable_budget = max_budget - min_budget + streamable_budget = weight_streaming_ctx.device_budget - requested_budget = min_budget + int(streamable_budget * 0.7) + requested_budget = int(streamable_budget * 0.7) weight_streaming_ctx.device_budget = requested_budget assert weight_streaming_ctx.device_budget == requested_budget optimized_model(*input) - # Full streaming by applying min budget - weight_streaming_ctx.device_budget = min_budget - assert weight_streaming_ctx.device_budget == min_budget + # Full streaming by applying 0 budget + weight_streaming_ctx.device_budget = 0 + assert weight_streaming_ctx.device_budget == 0 # Automatic weight streaming size - val = weight_streaming_ctx.get_automatic_weight_streaming_budget() - weight_streaming_ctx.device_budget = val - assert weight_streaming_ctx.device_budget > 0 + requested_budget = ( + weight_streaming_ctx.get_automatic_weight_streaming_budget() + ) + weight_streaming_ctx.device_budget = requested_budget + assert weight_streaming_ctx.device_budget == requested_budget - requested_budget = min_budget + int(streamable_budget * 0.5) + requested_budget = int(streamable_budget * 0.5) weight_streaming_ctx.device_budget = requested_budget assert weight_streaming_ctx.device_budget == requested_budget out = optimized_model(*input) # Weight streaming is disabled after the exit from weight streaming context - assert weight_streaming_ctx.device_budget == max_budget + assert weight_streaming_ctx.device_budget == streamable_budget ref = model(*input) torch.testing.assert_close( @@ -156,19 +157,19 @@ def test_weight_streaming_invalid_usage(self, _, use_python_runtime, multi_rt): # Setting weight streaming context to unsupported module with torchtrt.runtime.weight_streaming(model) as weight_streaming_ctx: - min_budget, max_budget = 
weight_streaming_ctx.get_required_device_budgets() - assert min_budget == max_budget + streamable_budget = weight_streaming_ctx.device_budget + assert streamable_budget == 0 with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx: - min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets() + streamable_budget = weight_streaming_ctx.device_budget # Values is larger than max budget disables weight streaming - weight_streaming_ctx.device_budget = max_budget + 1 - assert weight_streaming_ctx.device_budget == max_budget + weight_streaming_ctx.device_budget = streamable_budget + 1 + assert weight_streaming_ctx.device_budget == streamable_budget try: - # Runtime error if requested budget is less than mininum budget - weight_streaming_ctx.device_budget = min_budget - 1 + # Runtime error if requested budget is negative + weight_streaming_ctx.device_budget = -1 assert False except RuntimeError: assert True @@ -201,10 +202,9 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime): ) with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx: - min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets() - streamable_budget = max_budget - min_budget + streamable_budget = weight_streaming_ctx.device_budget for pct in [0.05, 0.2, 0.4, 0.8, 1.0]: - requested_budget = min_budget + int(streamable_budget * pct) + requested_budget = int(streamable_budget * pct) weight_streaming_ctx.device_budget = requested_budget # Budget distribution to multiple submodule may result in integer differences of at most 1 @@ -212,7 +212,7 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime): out = optimized_model(*input) # Weight streaming is disabled after the exit from weight streaming context - assert weight_streaming_ctx.device_budget == max_budget + assert weight_streaming_ctx.device_budget == streamable_budget ref = model(*input) torch.testing.assert_close(