Merge pull request #2521 from pytorch/cherry_pick_safe_mode_build_args
cherry-pick: Safe mode and Build Arguments PRs
gs-olive authored Dec 20, 2023
2 parents fdd6bad + 80743b0 commit b6dd22b
Showing 22 changed files with 624 additions and 69 deletions.
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.cpp
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
auto most_compatible_device = get_most_compatible_device(cuda_device);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);

rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
}

{
if (MULTI_DEVICE_SAFE_MODE) {
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
if (compiled_engine->profile_execution) {
device_profiler_guard =
4 changes: 4 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
}

} // namespace
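
The two bindings registered above are exposed through the ``torch.ops.tensorrt`` namespace once the runtime library is loaded. A minimal sketch of calling them from Python (assuming ``torch_tensorrt`` has been imported, which registers the ops):

    import torch
    import torch_tensorrt  # importing registers the TORCH_LIBRARY(tensorrt, ...) ops

    # Query the current flag; runtime.cpp initializes MULTI_DEVICE_SAFE_MODE to false
    print(torch.ops.tensorrt.get_multi_device_safe_mode())  # False

    # Enable the per-inference device consistency check
    torch.ops.tensorrt.set_multi_device_safe_mode(True)
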
17 changes: 15 additions & 2 deletions core/runtime/runtime.cpp
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;

c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
LOG_DEBUG("Target Device: " << target_device);
auto device_options = find_compatible_devices(target_device);
@@ -31,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
if (device.device_name == target_device.device_name) {
// First priority is selecting a candidate which agrees with the current device ID
// If such a device is found, we can select it and break out of the loop
if (device.id == current_device.id && best_match.id != current_device.id) {
if (device.id == current_device.id) {
best_match = device;
break;
}
// Second priority is selecting a candidate which agrees with the target device ID
// At deserialization time, the current device and target device may not agree
else if (device.id == target_device.id && best_match.id != target_device.id) {
else if (device.id == target_device.id) {
best_match = device;
}
// If no such GPU ID is found, select the first available candidate GPU
@@ -103,6 +105,17 @@ RTDevice get_current_device() {
return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
}

void multi_gpu_device_check() {
// If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
LOG_WARNING(
"Detected this engine is being instantitated in a multi-GPU system with "
<< "multi-device safe mode disabled. For more on the implications of this "
<< "as well as workarounds, see the linked documentation "
<< "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)");
}
}

namespace {
static DeviceList cuda_device_list;
}
3 changes: 3 additions & 0 deletions core/runtime/runtime.h
@@ -16,6 +16,7 @@ namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "4";
extern bool MULTI_DEVICE_SAFE_MODE;
typedef enum {
ABI_TARGET_IDX = 0,
NAME_IDX,
@@ -33,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

void multi_gpu_device_check();

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
DeviceMap device_list;
34 changes: 34 additions & 0 deletions docsrc/user_guide/runtime.rst
@@ -34,3 +34,37 @@ Plugin Library
In the case you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT
ships the library ``libtorchtrt_plugins.so`` which contains the implementation of the TensorRT plugins used by Torch-TensorRT during
compilation. This library can be ``DL_OPEN`` or ``LD_PRELOAD`` similar to other TensorRT plugin libraries.

Multi Device Safe Mode
----------------------

Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
the runtime checks for device consistency prior to every inference call.

There is a non-negligible, fixed cost per inference call when multi-device safe mode is enabled, which is why
it is now disabled by default. It can be controlled via the following convenience function, which
doubles as a context manager.

.. code-block:: python

    # Enables Multi Device Safe Mode
    torch_tensorrt.runtime.set_multi_device_safe_mode(True)

    # Disables Multi Device Safe Mode [Default Behavior]
    torch_tensorrt.runtime.set_multi_device_safe_mode(False)

    # Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
    with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
        ...

TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked.
Therefore, if the device changes in the active thread, as may happen when invoking
engines on multiple GPUs from the same Python process, safe mode causes Torch-TensorRT to display
an alert and switch GPUs accordingly. If safe mode is not enabled, the engine device and the
CUDA context device can fall out of agreement, which can cause the program to crash.

One technique for managing multiple TRT engines on different GPUs, while not sacrificing performance to
multi-device safe mode, is to use Python threads. Each thread is responsible for all of the TRT engines
on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is
responsible (this can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same
Python script without needing to switch CUDA contexts and incur performance overhead, as sketched below.
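
A minimal sketch of this pattern, assuming one compiled Torch-TensorRT module per GPU (``models`` and ``sample_inputs`` are hypothetical per-device lists):

.. code-block:: python

    import threading

    import torch

    results = {}

    def infer_on_gpu(device_id, model, inputs):
        # Pin this thread's default CUDA device to the GPU owning the engine
        torch.cuda.set_device(device_id)
        results[device_id] = model(*inputs)

    threads = [
        threading.Thread(target=infer_on_gpu, args=(i, models[i], sample_inputs[i]))
        for i in range(torch.cuda.device_count())
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
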
8 changes: 5 additions & 3 deletions py/torch_tensorrt/__init__.py
@@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str:
from torch_tensorrt._Device import Device # noqa: F401
from torch_tensorrt._enums import * # noqa: F403
from torch_tensorrt._Input import Input # noqa: F401
from torch_tensorrt.logging import *
from torch_tensorrt.ptq import *
from torch_tensorrt._utils import * # noqa: F403
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.logging import *
from torch_tensorrt.ptq import *
from torch_tensorrt.runtime import * # noqa: F403

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from torch_tensorrt import dynamo # noqa: F401
from torch_tensorrt.dynamo import backend # noqa: F401

from torch_tensorrt import dynamo # noqa: F401


def _register_with_torch() -> None:
trtorch_dir = os.path.dirname(__file__)
30 changes: 23 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -16,13 +16,21 @@
from torch_tensorrt.dynamo._defaults import (
DEBUG,
DEVICE,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
inputs: Any,
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
disable_tf32: bool = False,
sparse_weights: bool = False,
disable_tf32: bool = DISABLE_TF32,
sparse_weights: bool = SPARSE_WEIGHTS,
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
refit: bool = False,
engine_capability: EngineCapability = ENGINE_CAPABILITY,
refit: bool = REFIT,
debug: bool = DEBUG,
capability: EngineCapability = EngineCapability.default,
num_avg_timing_iters: int = 1,
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
workspace_size: int = WORKSPACE_SIZE,
dla_sram_size: int = 1048576,
dla_local_dram_size: int = 1073741824,
dla_global_dram_size: int = 536870912,
dla_sram_size: int = DLA_SRAM_SIZE,
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -200,6 +209,13 @@ def compile(
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"refit": refit,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
}

settings = CompilationSettings(**compilation_options)
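
With these options now plumbed through to ``CompilationSettings``, the new build arguments can be passed to the Dynamo frontend directly. A sketch under assumed inputs (``MyModule`` is a hypothetical model; the values shown are illustrative, with defaults taken from ``_defaults.py``):

    import torch
    import torch_tensorrt

    model = MyModule().eval().cuda()               # hypothetical user model
    inputs = [torch.randn(1, 3, 224, 224).cuda()]  # hypothetical sample input

    trt_model = torch_tensorrt.dynamo.compile(
        model,
        inputs,
        disable_tf32=True,       # default False (TF32 enabled for TRT layers)
        sparse_weights=True,     # default False
        refit=True,              # default False
        num_avg_timing_iters=2,  # default 1
        dla_sram_size=1048576,   # DLA memory sizes; defaults listed in _defaults.py
    )
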
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -1,19 +1,28 @@
import torch
from tensorrt import EngineCapability
from torch_tensorrt._Device import Device

PRECISION = torch.float32
DEBUG = False
DEVICE = None
DISABLE_TF32 = False
DLA_LOCAL_DRAM_SIZE = 1073741824
DLA_GLOBAL_DRAM_SIZE = 536870912
DLA_SRAM_SIZE = 1048576
ENGINE_CAPABILITY = EngineCapability.STANDARD
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
MAX_AUX_STREAMS = None
NUM_AVG_TIMING_ITERS = 1
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
SPARSE_WEIGHTS = False
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REFIT = False
REQUIRE_FULL_COMPILATION = False


25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -2,16 +2,25 @@
from typing import Optional, Set

import torch
from tensorrt import EngineCapability
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._defaults import (
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
"""

precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
disable_tf32: bool = DISABLE_TF32
sparse_weights: bool = SPARSE_WEIGHTS
refit: bool = REFIT
engine_capability: EngineCapability = ENGINE_CAPABILITY
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
dla_sram_size: int = DLA_SRAM_SIZE
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
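
``CompilationSettings`` is a dataclass, so the new fields can also be set directly when constructing settings by hand; unspecified fields fall back to the constants in ``_defaults.py``. A small sketch:

    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings(
        disable_tf32=True,
        refit=True,
        dla_sram_size=1 << 20,  # 1048576, matching the default DLA_SRAM_SIZE
    )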