From f93a732bc99b51a1e07e0042fdd9cf9ff38ad2b7 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 7 Dec 2023 16:17:34 -0800 Subject: [PATCH 1/2] fix/feat: Add support for multiple TRT Build Args (#2510) --- py/torch_tensorrt/dynamo/_compiler.py | 30 ++++-- py/torch_tensorrt/dynamo/_defaults.py | 9 ++ py/torch_tensorrt/dynamo/_settings.py | 25 +++++ .../dynamo/conversion/_TRTInterpreter.py | 78 +++++++++------ .../dynamo/conversion/_conversion.py | 13 +-- tests/py/dynamo/conversion/harness.py | 8 +- .../runtime/test_compilation_settings.py | 95 +++++++++++++++++++ 7 files changed, 204 insertions(+), 54 deletions(-) create mode 100644 tests/py/dynamo/runtime/test_compilation_settings.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 8d0092bfd6..d91ddab15f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -16,13 +16,21 @@ from torch_tensorrt.dynamo._defaults import ( DEBUG, DEVICE, + DISABLE_TF32, + DLA_GLOBAL_DRAM_SIZE, + DLA_LOCAL_DRAM_SIZE, + DLA_SRAM_SIZE, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENGINE_CAPABILITY, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, + NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, PRECISION, + REFIT, REQUIRE_FULL_COMPILATION, + SPARSE_WEIGHTS, TRUNCATE_LONG_AND_DOUBLE, USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, @@ -51,17 +59,18 @@ def compile( inputs: Any, *, device: Optional[Union[Device, torch.device, str]] = DEVICE, - disable_tf32: bool = False, - sparse_weights: bool = False, + disable_tf32: bool = DISABLE_TF32, + sparse_weights: bool = SPARSE_WEIGHTS, enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), - refit: bool = False, + engine_capability: EngineCapability = ENGINE_CAPABILITY, + refit: bool = REFIT, debug: bool = DEBUG, capability: EngineCapability = EngineCapability.default, - num_avg_timing_iters: int = 1, + num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, workspace_size: int = WORKSPACE_SIZE, - dla_sram_size: int = 1048576, - dla_local_dram_size: int = 1073741824, - dla_global_dram_size: int = 536870912, + dla_sram_size: int = DLA_SRAM_SIZE, + dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, calibrator: object = None, truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, @@ -200,6 +209,13 @@ def compile( "use_fast_partitioner": use_fast_partitioner, "enable_experimental_decompositions": enable_experimental_decompositions, "require_full_compilation": require_full_compilation, + "disable_tf32": disable_tf32, + "sparse_weights": sparse_weights, + "refit": refit, + "engine_capability": engine_capability, + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 103b5f7792..4ec872fb1b 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,19 +1,28 @@ import torch +from tensorrt import EngineCapability from torch_tensorrt._Device import Device PRECISION = torch.float32 DEBUG = False DEVICE = None +DISABLE_TF32 = False +DLA_LOCAL_DRAM_SIZE = 1073741824 +DLA_GLOBAL_DRAM_SIZE = 536870912 +DLA_SRAM_SIZE = 1048576 +ENGINE_CAPABILITY = EngineCapability.STANDARD WORKSPACE_SIZE = 0 MIN_BLOCK_SIZE = 5 
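For reference, the build arguments wired through in this patch were already accepted by ``torch_tensorrt.compile`` but are now forwarded to the TensorRT builder via ``CompilationSettings``. The sketch below is illustrative only; the module, input shapes, and specific values are placeholders and are not part of the change itself.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModule().eval().cuda()  # hypothetical user module
    inputs = [torch.randn(3, 5, 7).cuda()]

    # All values are examples; the defaults live in py/torch_tensorrt/dynamo/_defaults.py
    trt_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        disable_tf32=True,               # forbid TF32 tactics for FP32 layers
        sparse_weights=True,             # allow the builder to use sparse-weight kernels
        refit=True,                      # build a refittable engine
        num_avg_timing_iters=5,          # kernel timing averaging iterations
        workspace_size=1 << 30,          # workspace memory pool limit, in bytes
        dla_sram_size=1048576,           # DLA managed SRAM, in bytes
        dla_local_dram_size=1073741824,  # DLA local DRAM, in bytes
        dla_global_dram_size=536870912,  # DLA global DRAM, in bytes
    )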
PASS_THROUGH_BUILD_FAILURES = False MAX_AUX_STREAMS = None +NUM_AVG_TIMING_ITERS = 1 VERSION_COMPATIBLE = False OPTIMIZATION_LEVEL = None +SPARSE_WEIGHTS = False TRUNCATE_LONG_AND_DOUBLE = False USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False +REFIT = False REQUIRE_FULL_COMPILATION = False diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index c9f4534cb8..cd58c9547f 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -2,16 +2,25 @@ from typing import Optional, Set import torch +from tensorrt import EngineCapability from torch_tensorrt._Device import Device from torch_tensorrt.dynamo._defaults import ( DEBUG, + DISABLE_TF32, + DLA_GLOBAL_DRAM_SIZE, + DLA_LOCAL_DRAM_SIZE, + DLA_SRAM_SIZE, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENGINE_CAPABILITY, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, + NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, PRECISION, + REFIT, REQUIRE_FULL_COMPILATION, + SPARSE_WEIGHTS, TRUNCATE_LONG_AND_DOUBLE, USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, @@ -46,6 +55,14 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path + disable_tf32 (bool): Whether to disable TF32 computation for TRT layers + sparse_weights (bool): Whether to allow the builder to use sparse weights + refit (bool): Whether to build a refittable engine + engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels + num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels + dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. 
+ dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations + dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution """ precision: torch.dtype = PRECISION @@ -63,3 +80,11 @@ class CompilationSettings: enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS device: Device = field(default_factory=default_device) require_full_compilation: bool = REQUIRE_FULL_COMPILATION + disable_tf32: bool = DISABLE_TF32 + sparse_weights: bool = SPARSE_WEIGHTS + refit: bool = REFIT + engine_capability: EngineCapability = ENGINE_CAPABILITY + num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS + dla_sram_size: int = DLA_SRAM_SIZE + dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE + dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d9d5229901..eec7e62516 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,8 +4,6 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np - -# @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch import torch.fx @@ -97,6 +95,7 @@ def __init__( self._itensor_to_tensor_meta: Dict[ trt.tensorrt.ITensor, TensorMetadata ] = dict() + self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors self.output_dtypes = output_dtypes @@ -119,40 +118,25 @@ def validate_conversion(self) -> Set[str]: def run( self, - workspace_size: int = 0, - precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set - sparse_weights: bool = False, - disable_tf32: bool = False, force_fp32_output: bool = False, strict_type_constraints: bool = False, algorithm_selector: Optional[trt.IAlgorithmSelector] = None, timing_cache: Optional[trt.ITimingCache] = None, - profiling_verbosity: Optional[trt.ProfilingVerbosity] = None, tactic_sources: Optional[int] = None, - max_aux_streams: Optional[int] = None, - version_compatible: bool = False, - optimization_level: Optional[int] = None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. Args: - workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation. - precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). - sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity force_fp32_output: force output to be fp32 strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. algorithm_selector: set up algorithm selection for certain layer timing_cache: enable timing cache for TensorRT - profiling_verbosity: TensorRT logging level - max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine - version_compatible: Provide version forward-compatibility for engine plan files - optimization_level: Builder optimization 0-5, higher levels imply longer build time, - searching for more optimization options. TRT defaults to 3 Return: TRTInterpreterResult """ TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + precision = self.compilation_settings.precision # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and # force_fp32_output=False. 
Overriden by specifying output_dtypes self.output_fp16 = not force_fp32_output and precision == torch.float16 @@ -173,9 +157,9 @@ def run( builder_config = self.builder.create_builder_config() - if workspace_size != 0: + if self.compilation_settings.workspace_size != 0: builder_config.set_memory_pool_limit( - trt.MemoryPoolType.WORKSPACE, workspace_size + trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size ) cache = None @@ -188,21 +172,50 @@ def run( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( - profiling_verbosity - if profiling_verbosity + trt.ProfilingVerbosity.VERBOSE + if self.compilation_settings.debug else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) if version.parse(trt.__version__) >= version.parse("8.6"): - if max_aux_streams is not None: - _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") - builder_config.max_aux_streams = max_aux_streams - if version_compatible: + if self.compilation_settings.max_aux_streams is not None: + _LOGGER.info( + f"Setting max aux streams to {self.compilation_settings.max_aux_streams}" + ) + builder_config.max_aux_streams = ( + self.compilation_settings.max_aux_streams + ) + if self.compilation_settings.version_compatible: _LOGGER.info("Using version compatible") builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) - if optimization_level is not None: - _LOGGER.info(f"Using optimization level {optimization_level}") - builder_config.builder_optimization_level = optimization_level + if self.compilation_settings.optimization_level is not None: + _LOGGER.info( + f"Using optimization level {self.compilation_settings.optimization_level}" + ) + builder_config.builder_optimization_level = ( + self.compilation_settings.optimization_level + ) + + builder_config.engine_capability = self.compilation_settings.engine_capability + builder_config.avg_timing_iterations = ( + self.compilation_settings.num_avg_timing_iters + ) + + if self.compilation_settings.device.device_type == trt.DeviceType.DLA: + builder_config.DLA_core = self.compilation_settings.device.dla_core + _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}") + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_MANAGED_SRAM, + self.compilation_settings.dla_sram_size, + ) + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_LOCAL_DRAM, + self.compilation_settings.dla_local_dram_size, + ) + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_GLOBAL_DRAM, + self.compilation_settings.dla_global_dram_size, + ) if precision == torch.float16: builder_config.set_flag(trt.BuilderFlag.FP16) @@ -210,12 +223,15 @@ def run( if precision == torch.int8: builder_config.set_flag(trt.BuilderFlag.INT8) - if sparse_weights: + if self.compilation_settings.sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) - if disable_tf32: + if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) + if self.compilation_settings.refit: + builder_config.set_flag(trt.BuilderFlag.REFIT) + if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index c738b18bc7..d39b7f35c7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -53,18 +53,7 @@ def convert_module( output_dtypes=output_dtypes, compilation_settings=settings, ) - 
interpreter_result = interpreter.run( - workspace_size=settings.workspace_size, - precision=settings.precision, - profiling_verbosity=( - trt.ProfilingVerbosity.VERBOSE - if settings.debug - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY - ), - max_aux_streams=settings.max_aux_streams, - version_compatible=settings.version_compatible, - optimization_level=settings.optimization_level, - ) + interpreter_result = interpreter.run() if settings.use_python_runtime: return PythonTorchTensorRTModule( diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index be13f7d2c1..404f50a187 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -50,7 +50,6 @@ def run_test( interpreter, rtol, atol, - precision=torch.float, check_dtype=True, ): with torch.no_grad(): @@ -60,7 +59,7 @@ def run_test( mod.eval() start = time.perf_counter() - interpreter_result = interpreter.run(precision=precision) + interpreter_result = interpreter.run() sec = time.perf_counter() - start _LOGGER.info(f"Interpreter run time(s): {sec}") trt_mod = PythonTorchTensorRTModule( @@ -234,7 +233,9 @@ def run_test( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here - compilation_settings = CompilationSettings(truncate_long_and_double=True) + compilation_settings = CompilationSettings( + precision=precision, truncate_long_and_double=True + ) interp = TRTInterpreter( mod, @@ -248,7 +249,6 @@ def run_test( interp, rtol, atol, - precision, check_dtype, ) diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_compilation_settings.py new file mode 100644 index 0000000000..daa67ad032 --- /dev/null +++ b/tests/py/dynamo/runtime/test_compilation_settings.py @@ -0,0 +1,95 @@ +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + + +class TestEnableTRTFlags(TestCase): + def test_toggle_build_args(self): + class AddSoftmax(torch.nn.Module): + def forward(self, x): + x = 3 * x + y = x + 1 + return torch.softmax(y, 0) + + inputs = [ + torch.rand( + 3, + 5, + 7, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(AddSoftmax()) + + # Validate that the results between Torch and Torch-TRT are similar + # Enable multiple TRT build arguments + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + disable_tf32=True, + sparse_weights=True, + refit=True, + num_avg_timing_iters=5, + workspace_size=1 << 10, + truncate_long_and_double=True, + ) + + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"AddSoftmax TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_dla_args(self): + class AddSoftmax(torch.nn.Module): + def forward(self, x): + x = 3 * x + y = x + 1 + return torch.softmax(y, 0) + + inputs = [ + torch.rand( + 3, + 5, + 7, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(AddSoftmax()) + + # Validate that the results between Torch and Torch-TRT are similar + # Enable multiple TRT build arguments + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + 
device=torch_tensorrt.Device("dla:0", allow_gpu_fallback=True), + pass_through_build_failures=True, + dla_sram_size=1048577, + dla_local_dram_size=1073741825, + dla_global_dram_size=536870913, + ) + + # DLA is not present on the active machine + with self.assertRaises(RuntimeError): + optimized_model(*inputs).detach().cpu() + + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests() From 80743b09e68d5e003b061accf6e6f87c79d98e13 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 7 Dec 2023 16:18:00 -0800 Subject: [PATCH 2/2] feat: Safety Mode for Runtime (#2512) --- core/runtime/TRTEngine.cpp | 1 + core/runtime/execute_engine.cpp | 2 +- core/runtime/register_jit_hooks.cpp | 4 + core/runtime/runtime.cpp | 17 ++- core/runtime/runtime.h | 3 + docsrc/user_guide/runtime.rst | 34 +++++ py/torch_tensorrt/__init__.py | 8 +- .../dynamo/conversion/_conversion.py | 2 + .../partitioning/_adjacency_partitioner.py | 4 +- .../partitioning/_global_partitioner.py | 4 +- .../runtime/_PythonTorchTensorRTModule.py | 66 ++++++++- py/torch_tensorrt/dynamo/runtime/tools.py | 131 ++++++++++++++++++ py/torch_tensorrt/runtime/__init__.py | 1 + .../runtime/multi_device_safe_mode.py | 51 +++++++ setup.py | 2 + tests/py/dynamo/runtime/test_safe_mode.py | 105 ++++++++++++++ 16 files changed, 420 insertions(+), 15 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/runtime/tools.py create mode 100644 py/torch_tensorrt/runtime/__init__.py create mode 100644 py/torch_tensorrt/runtime/multi_device_safe_mode.py create mode 100644 tests/py/dynamo/runtime/test_safe_mode.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 1bca7869e3..13cbe3a126 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -52,6 +52,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); device_info = most_compatible_device.value(); + multi_gpu_device_check(); set_rt_device(device_info); rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 2a7fe884da..5551010a2a 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -74,7 +74,7 @@ std::vector execute_engine(std::vector inputs, c10::intr LOG_INFO("" << log_info); } - { + if (MULTI_DEVICE_SAFE_MODE) { std::unique_ptr device_profiler_guard; if (compiled_engine->profile_execution) { device_profiler_guard = diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index c5b9118fee..1acc27dda5 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) { m.def("execute_engine", execute_engine); m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); }); m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; }); + m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; }); + m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { + MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; + }); } } // namespace diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp index 0372258919..2d7f7f1198 100644 --- a/core/runtime/runtime.cpp +++ b/core/runtime/runtime.cpp @@ -7,6 +7,8 @@ namespace torch_tensorrt { 
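The ``MULTI_DEVICE_SAFE_MODE`` flag introduced in ``core/runtime/runtime.cpp`` is toggled from Python through the custom ops registered in ``register_jit_hooks.cpp``. A minimal sketch of that round trip, assuming the compiled ``torch_tensorrt._C`` extension is loaded:

.. code-block:: python

    import torch
    import torch_tensorrt  # importing the package loads the `tensorrt` custom op namespace

    # Query the current value of the C++ runtime flag (False by default)
    print(torch.ops.tensorrt.get_multi_device_safe_mode())

    # Flip the C++ flag directly; the torch_tensorrt.runtime.set_multi_device_safe_mode
    # wrapper is preferred since it also keeps the Python-side flag in sync
    torch.ops.tensorrt.set_multi_device_safe_mode(True)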
namespace core { namespace runtime { +bool MULTI_DEVICE_SAFE_MODE = false; + c10::optional get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) { LOG_DEBUG("Target Device: " << target_device); auto device_options = find_compatible_devices(target_device); @@ -31,13 +33,13 @@ c10::optional get_most_compatible_device(const RTDevice& target_device if (device.device_name == target_device.device_name) { // First priority is selecting a candidate which agrees with the current device ID // If such a device is found, we can select it and break out of the loop - if (device.id == current_device.id && best_match.id != current_device.id) { + if (device.id == current_device.id) { best_match = device; break; } // Second priority is selecting a candidate which agrees with the target device ID // At deserialization time, the current device and target device may not agree - else if (device.id == target_device.id && best_match.id != target_device.id) { + else if (device.id == target_device.id) { best_match = device; } // If no such GPU ID is found, select the first available candidate GPU @@ -103,6 +105,17 @@ RTDevice get_current_device() { return RTDevice(device_id, nvinfer1::DeviceType::kGPU); } +void multi_gpu_device_check() { + // If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user + if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) { + LOG_WARNING( + "Detected this engine is being instantitated in a multi-GPU system with " + << "multi-device safe mode disabled. For more on the implications of this " + << "as well as workarounds, see the linked documentation " + << "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)"); + } +} + namespace { static DeviceList cuda_device_list; } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 05d97a30b8..ea863850ba 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -16,6 +16,7 @@ namespace runtime { using EngineID = int64_t; const std::string ABI_VERSION = "4"; +extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, @@ -33,6 +34,8 @@ std::vector find_compatible_devices(const RTDevice& target_device); std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine); +void multi_gpu_device_check(); + class DeviceList { using DeviceMap = std::unordered_map; DeviceMap device_list; diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst index 0cfc93200f..8264abdd32 100644 --- a/docsrc/user_guide/runtime.rst +++ b/docsrc/user_guide/runtime.rst @@ -34,3 +34,37 @@ Plugin Library In the case you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT ships the library ``libtorchtrt_plugins.so`` which contains the implementation of the TensorRT plugins used by Torch-TensorRT during compilation. This library can be ``DL_OPEN`` or ``LD_PRELOAD`` similar to other TensorRT plugin libraries. + +Multi Device Safe Mode +--------------- + +Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether +the runtime checks for device consistency prior to every inference call. + +There is a non-negligible, fixed cost per-inference call when multi-device safe mode is enabled, which is why +it is now disabled by default. It can be controlled via the following convenience function which +doubles as a context manager. + +.. 
code-block:: python + + # Enables Multi Device Safe Mode + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + + # Disables Multi Device Safe Mode [Default Behavior] + torch_tensorrt.runtime.set_multi_device_safe_mode(False) + + # Enables Multi Device Safe Mode, then resets the safe mode to its prior setting + with torch_tensorrt.runtime.set_multi_device_safe_mode(True): + ... + +TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked. +Therefore, if the device were to change in the active thread, which may be the case when invoking +engines on multiple GPUs from the same Python process, safe mode will cause Torch-TensorRT to display +an alert and switch GPUs accordingly. If safe mode were not enabled, there could be a mismatch in the engine +device and CUDA context device, which could lead the program to crash. + +One technique for managing multiple TRT engines on different GPUs while not sacrificing performance for +multi-device safe mode is to use Python threads. Each thread is responsible for all of the TRT engines +on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is +responsible (can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same +Python script without needing to switch CUDA contexts and incur performance overhead. diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index c015bd89db..b9d2af39c5 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str: from torch_tensorrt._Device import Device # noqa: F401 from torch_tensorrt._enums import * # noqa: F403 from torch_tensorrt._Input import Input # noqa: F401 -from torch_tensorrt.logging import * -from torch_tensorrt.ptq import * from torch_tensorrt._utils import * # noqa: F403 from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt.logging import * +from torch_tensorrt.ptq import * +from torch_tensorrt.runtime import * # noqa: F403 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): - from torch_tensorrt import dynamo # noqa: F401 from torch_tensorrt.dynamo import backend # noqa: F401 + from torch_tensorrt import dynamo # noqa: F401 + def _register_with_torch() -> None: trtorch_dir = os.path.dirname(__file__) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d39b7f35c7..e4f0df5818 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -60,6 +60,8 @@ def convert_module( engine=interpreter_result.engine, input_names=list(interpreter_result.input_names), output_names=list(interpreter_result.output_names), + target_device=settings.device, + profiling_enabled=settings.debug, ) else: diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 55df3cb2b3..5bdbb8919b 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -42,10 +42,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name) + node in CONVERTERS or node.op == "get_attr" ) and node_name not in self.torch_executed_ops: # If node is 
a proper, supported computational node, store the operator - if not node.is_impure(): + if not node.is_impure() and node.op != "get_attr": if node_name not in self.supported_operators: self.supported_operators[node_name] = 1 else: diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index f6149a2271..092bdabfd0 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -150,10 +150,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name) + node in CONVERTERS or node.op == "get_attr" ) and node_name not in self.torch_executed_ops: # If node is a proper, supported computational node, store the operator - if not node.is_impure(): + if not node.is_impure() and node.op != "get_attr": if node_name not in self.supported_operators: self.supported_operators[node_name] = 1 else: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 41baecc7ab..db45609123 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -1,13 +1,22 @@ from __future__ import annotations import logging +from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple import tensorrt as trt import torch from torch.nn import Module +from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo.runtime.tools import ( + _is_switch_required, + _select_rt_device, + multi_gpu_device_check, +) from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter +import torch_tensorrt + logger = logging.getLogger(__name__) @@ -23,13 +32,26 @@ def __init__( engine: trt.ICudaEngine, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, + target_device: Device = Device._current_device(), + profiling_enabled: Optional[bool] = None, ): super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) + + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + self.engine = engine self.input_names = input_names if input_names is not None else [] self.output_names = output_names if output_names is not None else [] self.initialized = False + self.target_device_id = target_device.gpu_id + self.target_device_properties = torch.cuda.get_device_properties( + self.target_device_id + ) + self.profiling_enabled = ( + profiling_enabled if profiling_enabled is not None else False + ) self._initialize() def _initialize(self) -> None: @@ -119,6 +141,9 @@ def _load_from_state_dict( ) -> None: engine_bytes = state_dict[prefix + "engine"] + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + logger = trt.Logger() runtime = trt.Runtime(logger) self.engine = runtime.deserialize_cuda_engine(engine_bytes) @@ -141,15 +166,43 @@ def __setstate__(self, state: Dict[str, Any]) -> None: if self.engine: self.context = self.engine.create_execution_context() - def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:Forward" - ): + ) 
if self.profiling_enabled else nullcontext(): self._check_initialized() + # If in safe mode, check at each iteration for for whether a switch is required + if ( + torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + ): + curr_device_id = torch.cuda.current_device() + curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( + curr_device_id, + self.target_device_id, + curr_device_properties, + self.target_device_properties, + ): + device_id, _ = _select_rt_device( + curr_device_id, + self.target_device_id, + self.target_device_properties, + ) + device = torch.device(device_id) + torch.cuda.set_device(device_id) + + inputs = tuple([tensor.to(device) for tensor in inputs]) + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessInputs" - ): + ) if self.profiling_enabled else nullcontext(): assert len(inputs) == len( self.input_names ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." @@ -188,7 +241,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessOutputs" - ): + ) if self.profiling_enabled else nullcontext(): # create output tensors outputs: List[torch.Tensor] = [] @@ -215,7 +268,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: with torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:TensorRTRuntime" - ): + ) if self.profiling_enabled else nullcontext(): self.context.execute_async_v2( bindings, torch.cuda.current_stream().cuda_stream ) @@ -235,6 +288,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: if not self.context.profiler: self.context.profiler = trt.Profiler() if profiler is None else profiler + self.profiling_enabled = True + def disable_profiling(self) -> None: """ Disable TensorRT profiling. @@ -244,6 +299,7 @@ def disable_profiling(self) -> None: torch.cuda.synchronize() del self.context self.context = self.engine.create_execution_context() + self.profiling_enabled = False def get_layer_info(self) -> str: """ diff --git a/py/torch_tensorrt/dynamo/runtime/tools.py b/py/torch_tensorrt/dynamo/runtime/tools.py new file mode 100644 index 0000000000..75c83a4f60 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/tools.py @@ -0,0 +1,131 @@ +import logging +from typing import Optional, Tuple + +import torch + +import torch_tensorrt + +logger = logging.getLogger(__name__) + + +def multi_gpu_device_check() -> None: + # If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user + if ( + not torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + and torch.cuda.device_count() > 1 + ): + logger.warning( + "Detected this engine is being instantitated in a multi-GPU system with " + "multi-device safe mode disabled. For more on the implications of this " + "as well as workarounds, see the linked documentation " + "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode). " + f"The engine is set to be instantiated on the current default cuda device, cuda:{torch.cuda.current_device()}. " + "If this is incorrect, please set the desired cuda device via torch.cuda.set_device(...) and retry." 
+ ) + + +def _is_switch_required( + curr_device_id: int, + engine_device_id: int, + curr_device_properties: torch._C._CudaDeviceProperties, + engine_device_properties: torch._C._CudaDeviceProperties, +) -> bool: + """Determines whether a device switch is required based on input device parameters""" + # Device Capabilities disagree + if (curr_device_properties.major, curr_device_properties.minor) != ( + engine_device_properties.major, + engine_device_properties.minor, + ): + logger.warning( + f"Configured SM capability {(engine_device_properties.major, engine_device_properties.minor)} does not match with " + f"current device SM capability {(curr_device_properties.major, curr_device_properties.minor)}. Switching device context." + ) + + return True + + # Names disagree + if curr_device_properties.name != engine_device_properties.name: + logger.warning( + f"Program compiled for {engine_device_properties.name} but current CUDA device is " + f"current device SM capability {curr_device_properties.name}. Attempting to switch device context for better compatibility." + ) + + return True + + # Device IDs disagree + if curr_device_id != engine_device_id: + logger.warning( + f"Configured Device ID: {engine_device_id} is different than current device ID: " + f"{curr_device_id}. Attempting to switch device context for better compatibility." + ) + + return True + + return False + + +def _select_rt_device( + curr_device_id: int, + engine_device_id: int, + engine_device_properties: torch._C._CudaDeviceProperties, +) -> Tuple[int, torch._C._CudaDeviceProperties]: + """Wraps compatible device check and raises error if none are found""" + new_target_device_opt = _get_most_compatible_device( + curr_device_id, engine_device_id, engine_device_properties + ) + + assert ( + new_target_device_opt is not None + ), "Could not find a compatible device on the system to run TRT Engine" + + return new_target_device_opt + + +def _get_most_compatible_device( + curr_device_id: int, + engine_device_id: int, + engine_device_properties: torch._C._CudaDeviceProperties, +) -> Optional[Tuple[int, torch._C._CudaDeviceProperties]]: + """Selects a runtime device based on compatibility checks""" + all_devices = [ + (i, torch.cuda.get_device_properties(i)) + for i in range(torch.cuda.device_count()) + ] + logger.debug(f"All available devices: {all_devices}") + target_device_sm = (engine_device_properties.major, engine_device_properties.minor) + + # Any devices with the same SM capability are valid candidates + candidate_devices = [ + (i, device_properties) + for i, device_properties in all_devices + if (device_properties.major, device_properties.minor) == target_device_sm + ] + + logger.debug(f"Found candidate devices: {candidate_devices}") + + # If less than 2 candidates are found, return + if len(candidate_devices) <= 1: + return candidate_devices[0] if candidate_devices else None + + # If more than 2 candidates are found, select the best match + best_match = None + + for candidate in candidate_devices: + i, device_properties = candidate + # First priority is selecting a candidate which agrees with the current device ID + # If such a device is found, we can select it and break out of the loop + if device_properties.name == engine_device_properties.name: + if i == curr_device_id: + best_match = candidate + break + + # Second priority is selecting a candidate which agrees with the target device ID + # At deserialization time, the current device and target device may not agree + elif i == engine_device_id: + best_match = candidate + + 
# If no such GPU ID is found, select the first available candidate GPU + elif best_match is None: + best_match = candidate + + return best_match diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py new file mode 100644 index 0000000000..29895c83d5 --- /dev/null +++ b/py/torch_tensorrt/runtime/__init__.py @@ -0,0 +1 @@ +from .multi_device_safe_mode import set_multi_device_safe_mode diff --git a/py/torch_tensorrt/runtime/multi_device_safe_mode.py b/py/torch_tensorrt/runtime/multi_device_safe_mode.py new file mode 100644 index 0000000000..0ddd900ab6 --- /dev/null +++ b/py/torch_tensorrt/runtime/multi_device_safe_mode.py @@ -0,0 +1,51 @@ +import logging +from importlib.util import find_spec +from typing import Any + +import torch + +if find_spec("torch_tensorrt._C") is not None: + _PY_RT_MULTI_DEVICE_SAFE_MODE = torch.ops.tensorrt.get_multi_device_safe_mode() +else: + _PY_RT_MULTI_DEVICE_SAFE_MODE = False + + +logger = logging.getLogger(__name__) + + +class _MultiDeviceSafeModeContextManager(object): + """Helper class used in conjunction with `set_multi_device_safe_mode` + + Used to enable `set_multi_device_safe_mode` as a dual-purpose context manager + """ + + def __init__(self, old_mode: bool) -> None: + self.old_mode = old_mode + + def __enter__(self) -> "_MultiDeviceSafeModeContextManager": + return self + + def __exit__(self, *args: Any) -> None: + # Set multi-device safe mode back to old mode in Python + global _PY_RT_MULTI_DEVICE_SAFE_MODE + _PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode + + # Set multi-device safe mode back to old mode in C++ + if find_spec("torch_tensorrt._C") is not None: + torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode) + + +def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager: + # Fetch existing safe mode and set new mode for Python + global _PY_RT_MULTI_DEVICE_SAFE_MODE + old_mode = _PY_RT_MULTI_DEVICE_SAFE_MODE + _PY_RT_MULTI_DEVICE_SAFE_MODE = mode + + # Set new mode for C++ + if find_spec("torch_tensorrt._C") is not None: + torch.ops.tensorrt.set_multi_device_safe_mode(mode) + + logger.info(f"Set multi-device safe mode to {mode}") + + # Return context manager in case the function is used in a `with` call + return _MultiDeviceSafeModeContextManager(old_mode) diff --git a/setup.py b/setup.py index 82f1ac42f7..38d2121461 100644 --- a/setup.py +++ b/setup.py @@ -403,6 +403,7 @@ def run(self): "torch_tensorrt.fx.tracer", "torch_tensorrt.fx.tracer.acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer", + "torch_tensorrt.runtime", ] package_dir = { @@ -430,6 +431,7 @@ def run(self): "torch_tensorrt.fx.tracer": "py/torch_tensorrt/fx/tracer", "torch_tensorrt.fx.tracer.acc_tracer": "py/torch_tensorrt/fx/tracer/acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer": "py/torch_tensorrt/fx/tracer/dispatch_tracer", + "torch_tensorrt.runtime": "py/torch_tensorrt/runtime", } package_data = {} diff --git a/tests/py/dynamo/runtime/test_safe_mode.py b/tests/py/dynamo/runtime/test_safe_mode.py new file mode 100644 index 0000000000..bd196b12f0 --- /dev/null +++ b/tests/py/dynamo/runtime/test_safe_mode.py @@ -0,0 +1,105 @@ +import torch +from torch.testing._internal.common_utils import TestCase, run_tests + +import torch_tensorrt + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + + +class TestSafeMode(TestCase): + def test_multi_device_safe_mode_on(self): + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + self.assertTrue(torch.ops.tensorrt.get_multi_device_safe_mode()) + + def 
test_multi_device_safe_mode_off(self): + torch_tensorrt.runtime.set_multi_device_safe_mode(False) + self.assertFalse(torch.ops.tensorrt.get_multi_device_safe_mode()) + + def test_multi_device_safe_mode_context(self): + with torch_tensorrt.runtime.set_multi_device_safe_mode(True): + self.assertTrue(torch.ops.tensorrt.get_multi_device_safe_mode()) + self.assertFalse(torch.ops.tensorrt.get_multi_device_safe_mode()) + + def test_multi_device_safe_mode_enabled_inference_python(self): + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode Python TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_multi_device_safe_mode_enabled_inference_cpp(self): + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [ + torch.randn( + 3, + 5, + 7, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=False, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Safe Mode C++ TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests()
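As a companion to the guidance added in ``docsrc/user_guide/runtime.rst``, the following is a sketch of the thread-per-GPU pattern described there, which avoids the per-inference overhead of multi-device safe mode while still serving engines on several GPUs. The engine objects and batch iterables are placeholders, not part of this patch.

.. code-block:: python

    import threading

    import torch
    import torch_tensorrt

    # Safe mode stays off (the default), so per-inference device checks are skipped
    torch_tensorrt.runtime.set_multi_device_safe_mode(False)

    def run_on_gpu(device_id, trt_module, batches):
        # Pin this thread's default CUDA device to the GPU that owns the engine
        torch.cuda.set_device(device_id)
        for batch in batches:
            trt_module(batch.to(f"cuda:{device_id}"))

    # One thread per GPU; trt_mod_0 / trt_mod_1 are engines compiled for those devices (hypothetical)
    threads = [
        threading.Thread(target=run_on_gpu, args=(0, trt_mod_0, batches_0)),
        threading.Thread(target=run_on_gpu, args=(1, trt_mod_1, batches_1)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()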