Skip to content

Commit 1d5dd56

Browse files
authored
feat: Lazy engine initialization (#2997)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 577c5c4 commit 1d5dd56

16 files changed

+523
-107
lines changed

py/torch_tensorrt/_compile.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from torch.export import ExportedProgram
2828
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
2929
from torch_tensorrt.dynamo._compiler import (
30-
convert_module_to_trt_engine as dynamo_convert_module_to_trt_engine,
30+
convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
3131
)
3232
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
3333

@@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine(
354+
return dynamo_convert_exported_program_to_serialized_trt_engine(
355355
exp_program,
356356
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,

py/torch_tensorrt/dynamo/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
logger = logging.getLogger(__name__)
88

99
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
10-
from ._compiler import compile, convert_module_to_trt_engine
10+
from ._compiler import compile, convert_exported_program_to_serialized_trt_engine
1111
from ._exporter import export
1212
from ._refit import refit_module_weights
1313
from ._settings import CompilationSettings

py/torch_tensorrt/dynamo/_compiler.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ def compile(
7979
dryrun: bool = _defaults.DRYRUN,
8080
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
8181
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
82+
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
8283
**kwargs: Any,
8384
) -> torch.fx.GraphModule:
8485
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -141,6 +142,7 @@ def compile(
141142
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
142143
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
143144
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
145+
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
144146
**kwargs: Any,
145147
Returns:
146148
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -236,6 +238,7 @@ def compile(
236238
"dryrun": dryrun,
237239
"hardware_compatible": hardware_compatible,
238240
"timing_cache_path": timing_cache_path,
241+
"lazy_engine_init": lazy_engine_init,
239242
}
240243

241244
settings = CompilationSettings(**compilation_options)
@@ -454,6 +457,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
454457
# Replace all FX Modules with TRT Modules
455458
for name, trt_module in trt_modules.items():
456459
setattr(partitioned_module, name, trt_module)
460+
if settings.lazy_engine_init:
461+
getattr(partitioned_module, name).setup_engine()
457462

458463
# Reset settings object to user specification after fallback to global partitioning mode
459464
if fast_partitioner_failed:
@@ -464,7 +469,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
464469
return partitioned_module
465470

466471

467-
def convert_module_to_trt_engine(
472+
def convert_exported_program_to_serialized_trt_engine(
468473
exported_program: ExportedProgram,
469474
inputs: Sequence[Any],
470475
*,
@@ -647,10 +652,5 @@ def convert_module_to_trt_engine(
647652
exc_info=True,
648653
)
649654

650-
import io
651-
652-
with io.BytesIO() as engine_bytes:
653-
engine_bytes.write(interpreter_result.engine)
654-
engine_bytearray: bytes = engine_bytes.getvalue()
655-
656-
return engine_bytearray
655+
serialized_engine: bytes = interpreter_result.serialized_engine
656+
return serialized_engine

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
HARDWARE_COMPATIBLE = False
3333
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
3434
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
35+
LAZY_ENGINE_INIT = False
3536

3637

3738
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
ENABLED_PRECISIONS,
1717
ENGINE_CAPABILITY,
1818
HARDWARE_COMPATIBLE,
19+
LAZY_ENGINE_INIT,
1920
MAKE_REFITABLE,
2021
MAX_AUX_STREAMS,
2122
MIN_BLOCK_SIZE,
@@ -104,3 +105,4 @@ class CompilationSettings:
104105
dryrun: Union[bool, str] = DRYRUN
105106
hardware_compatible: bool = HARDWARE_COMPATIBLE
106107
timing_cache_path: str = TIMING_CACHE_PATH
108+
lazy_engine_init: bool = LAZY_ENGINE_INIT

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+8-5
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
1+
import io
12
import logging
23
import os
34
import warnings
45
from datetime import datetime
56
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
67

78
import numpy as np
8-
import tensorrt as trt
99
import torch
1010
import torch.fx
1111
from torch.fx.node import _get_qualified_name
@@ -29,6 +29,7 @@
2929
from torch_tensorrt.fx.observer import Observer
3030
from torch_tensorrt.logging import TRT_LOGGER
3131

32+
import tensorrt as trt
3233
from packaging import version
3334

3435
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ class UnsupportedOperatorException(RuntimeError):
4344

4445

4546
class TRTInterpreterResult(NamedTuple):
46-
engine: Any
47+
serialized_engine: bytes
4748
input_names: Sequence[str]
4849
output_names: Sequence[str]
4950

@@ -358,9 +359,11 @@ def run(
358359
builder_config, self.compilation_settings.timing_cache_path
359360
)
360361

361-
return TRTInterpreterResult(
362-
serialized_engine, self._input_names, self._output_names
363-
)
362+
with io.BytesIO() as engine_bytes:
363+
engine_bytes.write(serialized_engine)
364+
engine_str = engine_bytes.getvalue()
365+
366+
return TRTInterpreterResult(engine_str, self._input_names, self._output_names)
364367

365368
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
366369
self._cur_node_name = get_node_name(n)

py/torch_tensorrt/dynamo/conversion/_conversion.py

+19-23
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import io
43
import logging
54
from typing import List, Sequence
65

@@ -102,33 +101,30 @@ def convert_module(
102101
settings: Compilation settings
103102
name: TRT engine name
104103
Returns:
105-
_PythonTorchTensorRTModule or TorchTensorRTModule
104+
PythonTorchTensorRTModule or TorchTensorRTModule
106105
"""
107106
interpreter_result = interpret_module_to_result(module, inputs, settings)
108107

109-
if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime:
110-
if not settings.use_python_runtime:
111-
logger.info(
112-
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
113-
)
114-
return PythonTorchTensorRTModule(
115-
engine=interpreter_result.engine,
116-
input_names=list(interpreter_result.input_names),
117-
output_names=list(interpreter_result.output_names),
118-
settings=settings,
119-
)
108+
rt_cls = PythonTorchTensorRTModule
109+
110+
if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
120111

121-
else:
122112
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
123113

124-
with io.BytesIO() as engine_bytes:
125-
engine_bytes.write(interpreter_result.engine)
126-
engine_str = engine_bytes.getvalue()
114+
rt_cls = TorchTensorRTModule
115+
116+
elif (
117+
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
118+
):
127119

128-
return TorchTensorRTModule(
129-
serialized_engine=engine_str,
130-
name=name,
131-
input_binding_names=list(interpreter_result.input_names),
132-
output_binding_names=list(interpreter_result.output_names),
133-
settings=settings,
120+
logger.info(
121+
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
134122
)
123+
124+
return rt_cls(
125+
serialized_engine=interpreter_result.serialized_engine,
126+
input_binding_names=list(interpreter_result.input_names),
127+
output_binding_names=list(interpreter_result.output_names),
128+
name=name,
129+
settings=settings,
130+
)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+51-21
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
from contextlib import nullcontext
55
from typing import Any, Dict, List, Optional, Sequence, Tuple
66

7-
import tensorrt as trt
87
import torch
8+
import torch_tensorrt
99
from torch.nn import Module
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import dtype
@@ -18,7 +18,7 @@
1818
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
1919
from torch_tensorrt.logging import TRT_LOGGER
2020

21-
import torch_tensorrt
21+
import tensorrt as trt
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -32,17 +32,45 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
3232

3333
def __init__(
3434
self,
35-
engine: bytes,
36-
input_names: Optional[List[str]] = None,
37-
output_names: Optional[List[str]] = None,
35+
serialized_engine: Optional[bytes] = None,
36+
input_binding_names: Optional[List[str]] = None,
37+
output_binding_names: Optional[List[str]] = None,
38+
*,
39+
name: str = "",
3840
settings: CompilationSettings = CompilationSettings(),
3941
):
42+
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
43+
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine.
44+
45+
Arguments:
46+
serialized_engine (bytes): Serialized TensorRT engine in the form of a bytes object
47+
input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
48+
output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
49+
50+
Keyword Arguments:
51+
name (str): Name for module
52+
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile the engine; assumes the engine was built with default compilation settings if no object is passed
53+
54+
Example:
55+
56+
.. code-block:: py
57+
58+
trt_module = PythonTorchTensorRTModule(
59+
engine_str,
60+
input_binding_names=["x"],
61+
output_binding_names=["output"],
62+
name="my_module",
63+
settings=CompilationSettings(device=torch.cuda.current_device())
64+
)
65+
66+
"""
4067
super(PythonTorchTensorRTModule, self).__init__()
4168
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
4269

4370
# Run multi-gpu device check to validate engine instantiation
4471
multi_gpu_device_check()
4572

73+
self.name = name
4674
self.input_buffers: List[torch.Tensor] = []
4775
self.output_buffers: List[torch.Tensor] = []
4876
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
@@ -55,9 +83,13 @@ def __init__(
5583
# Unused currently - to be used by Dynamic Shape support implementation
5684
self.memory_pool = None
5785

58-
self.engine = engine
59-
self.input_names = input_names if input_names is not None else []
60-
self.output_names = output_names if output_names is not None else []
86+
self.serialized_engine = serialized_engine
87+
self.input_names = (
88+
input_binding_names if input_binding_names is not None else []
89+
)
90+
self.output_names = (
91+
output_binding_names if output_binding_names is not None else []
92+
)
6193
self.initialized = False
6294
self.target_device_id = (
6395
settings.device.gpu_id
@@ -69,12 +101,15 @@ def __init__(
69101
)
70102
self.profiling_enabled = settings.debug if settings.debug is not None else False
71103
self.settings = settings
72-
self._initialize()
104+
self.engine = None
105+
106+
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
107+
self.setup_engine()
73108

74-
def _initialize(self) -> None:
109+
def setup_engine(self) -> None:
75110
self.initialized = True
76111
runtime = trt.Runtime(TRT_LOGGER)
77-
self.engine = runtime.deserialize_cuda_engine(self.engine)
112+
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
78113
self.context = self.engine.create_execution_context()
79114

80115
assert self.engine.num_io_tensors == (
@@ -114,8 +149,7 @@ def _check_initialized(self) -> None:
114149
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
115150

116151
def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
117-
self._check_initialized()
118-
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
152+
state_dict[prefix + "engine"] = self.serialized_engine
119153
state_dict[prefix + "input_names"] = self.input_names
120154
state_dict[prefix + "output_names"] = self.output_names
121155

@@ -129,17 +163,13 @@ def _load_from_state_dict(
129163
unexpected_keys: Any,
130164
error_msgs: Any,
131165
) -> None:
132-
engine_bytes = state_dict[prefix + "engine"]
166+
self.serialized_engine = state_dict[prefix + "engine"]
167+
self.input_names = state_dict[prefix + "input_names"]
168+
self.output_names = state_dict[prefix + "output_names"]
133169

134170
# Run multi-gpu device check to validate engine instantiation
135171
multi_gpu_device_check()
136-
137-
runtime = trt.Runtime(TRT_LOGGER)
138-
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
139-
140-
self.input_names = state_dict[prefix + "input_names"]
141-
self.output_names = state_dict[prefix + "output_names"]
142-
self._initialize()
172+
self.setup_engine()
143173

144174
def __getstate__(self) -> Dict[str, Any]:
145175
state = self.__dict__.copy()

0 commit comments

Comments
 (0)