
Commit fbfc863

feat: engine caching
- revert backend changes
- update dynamo path
- add save_engine_cache and load_engine_cache args
- support customizing engine cache class
- refactor and add LRU to clear cache
- fix bug
1 parent 19f671d commit fbfc863
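
Before the full diff, a minimal sketch of the flags this commit adds to torch_trt.dynamo.compile, condensed from the example file added below. The model, input shape, and cache size here are illustrative only.

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = (torch.randn(2, 3, 224, 224, device="cuda"),)
exp_program = torch.export.export(model, inputs)

# The first compilation builds and caches the TRT engines on disk; later
# compilations with load_engine_cache=True can reuse them instead of rebuilding.
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    make_refitable=True,        # as in the example file below
    save_engine_cache=True,     # write compiled engines to the cache
    load_engine_cache=True,     # read engines back from the cache if present
    engine_cache_size=1 << 30,  # 1 GB on-disk budget for cached engines
)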

File tree

7 files changed: +493 -2 lines changed
+174
@@ -0,0 +1,174 @@
import ast
import logging
import os
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

_LOGGER: logging.Logger = logging.getLogger(__name__)


np.random.seed(0)
torch.manual_seed(0)
size = (100, 3, 224, 224)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)
def dynamo_path(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    # Mark the dim0 of inputs as dynamic
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )

    for i in range(iterations):
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
        remove_timing_cache()  # remove timing cache for engine caching measurement
        # The first iteration compiles with caching disabled; later iterations
        # save to and load from the engine cache.
        if i == 0:
            save_engine_cache = False
            load_engine_cache = False
        else:
            save_engine_cache = True
            load_engine_cache = True

        start.record()
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            enabled_precisions=enabled_precisions,
            debug=debug,
            min_block_size=min_block_size,
            make_refitable=True,
            save_engine_cache=save_engine_cache,
            load_engine_cache=load_engine_cache,
            engine_cache_size=1 << 30,  # 1GB
        )
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("-----dynamo_path-----> compilation time:", times, "milliseconds")
# Custom Engine Cache
class MyEngineCache(BaseEngineCache):

    def __init__(
        self,
        engine_cache_size: int,
        engine_cache_dir: str,
    ) -> None:
        self.total_engine_cache_size = engine_cache_size
        self.available_engine_cache_size = engine_cache_size
        self.engine_cache_dir = engine_cache_dir

    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> bool:
        # Engines are stored as <cache_dir>/<hash>/engine--<input_names>--<output_names>.trt
        path = os.path.join(
            self.engine_cache_dir,
            f"{hash}/engine--{input_names}--{output_names}.trt",
        )
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as f:
                f.write(serialized_engine)
        except Exception as e:
            _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
            return False

        _LOGGER.info(f"A TRT engine was cached to {path}")
        serialized_engine_size = len(serialized_engine)  # size in bytes of the serialized engine
        self.available_engine_cache_size -= serialized_engine_size
        return True

    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        directory = os.path.join(self.engine_cache_dir, hash)
        if os.path.exists(directory):
            engine_list = os.listdir(directory)
            assert (
                len(engine_list) == 1
            ), f"There are more than one engine {engine_list} under {directory}."
            path = os.path.join(directory, engine_list[0])
            # Recover the input/output names from the file name.
            input_names_str, output_names_str = (
                engine_list[0].split(".trt")[0].split("--")[1:]
            )
            input_names = ast.literal_eval(input_names_str)
            output_names = ast.literal_eval(output_names_str)
            with open(path, "rb") as f:
                serialized_engine = f.read()
            return serialized_engine, input_names, output_names
        else:
            return None, [], []
def compile_path(iterations=3):
    times = []
    engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i in range(iterations):
        inputs = [torch.rand(size).to("cuda")]
        # remove timing cache and reset dynamo for engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            save_engine_cache = False
            load_engine_cache = False
        else:
            save_engine_cache = True
            load_engine_cache = True

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": use_python_runtime,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refitable": True,
                "save_engine_cache": save_engine_cache,
                "load_engine_cache": load_engine_cache,
                "engine_cache_instance": engine_cache,  # use custom engine cache
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("-----compile_path-----> compilation time:", times, "milliseconds")
if __name__ == "__main__":
    dynamo_path()
    compile_path()
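
For reference, a standalone sketch of what the MyEngineCache class above puts on disk. The hash string and engine payload are made-up stand-ins; the directory matches the one used in compile_path.

# Hypothetical hash and payload, only to illustrate the on-disk layout.
cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
cache.save("deadbeef", b"<serialized TRT engine>", ["input_0"], ["output_0"])
# Writes /tmp/your_dir/deadbeef/engine--['input_0']--['output_0'].trt

engine_bytes, input_names, output_names = cache.load("deadbeef")
assert input_names == ["input_0"] and output_names == ["output_0"]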

py/torch_tensorrt/dynamo/_compiler.py

+39
@@ -18,6 +18,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -83,6 +84,11 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
+    save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE,
+    load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
+    engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
+    engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
+    engine_cache_instance: Optional[BaseEngineCache] = None,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -148,6 +154,11 @@ def compile(
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
+        save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
+        load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
+        engine_cache_dir (str): Directory to store the cached TRT engines
+        engine_cache_size (int): Maximum hard-disk space to use for the engine cache
+        engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -224,6 +235,11 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
+    if engine_cache_instance is None:
+        engine_cache_instance = EngineCacheInstanceCreator.get_creator(
+            engine_cache_size, engine_cache_dir
+        ).engine_cache_instance
+
     compilation_options = {
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
@@ -257,6 +273,11 @@ def compile(
         "hardware_compatible": hardware_compatible,
         "timing_cache_path": timing_cache_path,
         "lazy_engine_init": lazy_engine_init,
+        "save_engine_cache": save_engine_cache,
+        "load_engine_cache": load_engine_cache,
+        "engine_cache_dir": engine_cache_dir,
+        "engine_cache_size": engine_cache_size,
+        "engine_cache_instance": engine_cache_instance,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -703,3 +724,21 @@ def convert_exported_program_to_serialized_trt_engine(
 
     serialized_engine: bytes = interpreter_result.serialized_engine
     return serialized_engine
+
+
+class EngineCacheInstanceCreator:
+    engine_cache_creator = None
+
+    def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None:
+        self.engine_cache_instance = EngineCache(
+            engine_cache_size=engine_cache_size,
+            engine_cache_dir=engine_cache_dir,
+        )
+
+    @classmethod
+    def get_creator(
+        cls, engine_cache_size: int, engine_cache_dir: str
+    ) -> EngineCacheInstanceCreator:
+        if cls.engine_cache_creator is None:
+            cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir)
+        return cls.engine_cache_creator
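
The EngineCacheInstanceCreator added at the bottom of this file acts as a process-wide singleton: the first get_creator call builds the default EngineCache, and every later compile() call that does not pass engine_cache_instance reuses that same object. A small sketch of that behavior, assuming the class is imported from the module shown in this diff:

from torch_tensorrt.dynamo._compiler import EngineCacheInstanceCreator

# Both calls return the same creator (and the same underlying EngineCache),
# even though the second call passes different arguments.
a = EngineCacheInstanceCreator.get_creator(1 << 30, "/tmp/torch_tensorrt_engine_cache")
b = EngineCacheInstanceCreator.get_creator(1 << 20, "/tmp/some_other_dir")
assert a is b
assert a.engine_cache_instance is b.engine_cache_instance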

py/torch_tensorrt/dynamo/_defaults.py

+11 -1
@@ -4,6 +4,7 @@
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
+from torch_tensorrt.dynamo._engine_caching import EngineCache
 
 ENABLED_PRECISIONS = {dtype.f32}
 DEBUG = False
@@ -31,8 +32,17 @@
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
 SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
-TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
+TIMING_CACHE_PATH = os.path.join(
+    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
+)
 LAZY_ENGINE_INIT = False
+SAVE_ENGINE_CACHE = True
+LOAD_ENGINE_CACHE = True
+ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
+ENGINE_CACHE_SIZE = 1073741824
+ENGINE_CACHE_INSTANCE = EngineCache(
+    engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR
+)
 
 
 def default_device() -> Device:
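
A note on the defaults above: the ENGINE_CACHE_SIZE value of 1073741824 bytes is 1 << 30, the same 1 GB budget the example passes explicitly, and both the relocated timing cache and the new engine cache default to a torch_tensorrt_engine_cache folder under the system temp directory, so all cache artifacts live in one place.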
