
feat: engine caching #2995

Merged · 14 commits · Aug 29, 2024
65 changes: 65 additions & 0 deletions examples/dynamo/engine_caching_bert_example.py
@@ -0,0 +1,65 @@
import numpy as np
import torch
import torch_tensorrt
from engine_caching_example import remove_timing_cache
from transformers import BertModel

np.random.seed(0)
torch.manual_seed(0)

model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]


def compile_bert(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration also has to compile and save the engine, it will be slower than the 1st.
    # The 3rd iteration should be faster than the 1st because it loads the cached engine.
    for i in range(iterations):
        # Remove the timing cache and reset dynamo for a clean engine-caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compilation_kwargs = {
            "use_python_runtime": False,
            "enabled_precisions": {torch.float},
            "truncate_double": True,
            "debug": False,
            "min_block_size": 1,
            "make_refitable": True,
            "cache_built_engines": cache_built_engines,
            "reuse_cached_engines": reuse_cached_engines,
            "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
            "engine_cache_size": 1 << 30,  # 1 GB
        }
        optimized_model = torch.compile(
            model,
            backend="torch_tensorrt",
            options=compilation_kwargs,
        )
        optimized_model(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("-----compile bert-----> compilation time:\n", times, "milliseconds")


if __name__ == "__main__":
    compile_bert()
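After a run, the cached engines land under the engine_cache_dir passed above, and the 2nd and 3rd timings should drop accordingly. A minimal sketch to confirm the cache was populated (the layout inside the directory is an implementation detail of the disk cache, so treat the filenames as opaque):

import os

cache_dir = "/tmp/torch_trt_bert_engine_cache"  # engine_cache_dir used above
if os.path.isdir(cache_dir):
    for entry in sorted(os.listdir(cache_dir)):
        path = os.path.join(cache_dir, entry)
        if os.path.isfile(path):
            print(f"{entry}: {os.path.getsize(path)} bytes")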
160 changes: 160 additions & 0 deletions examples/dynamo/engine_caching_example.py
@@ -0,0 +1,160 @@
import os
from typing import Optional

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

np.random.seed(0)
torch.manual_seed(0)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)


def dynamo_compile(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    # Mark dim 0 of the input as dynamic
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration also has to compile and save the engine, it will be slower than the 1st.
    # The 3rd iteration should be faster than the 1st because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
        remove_timing_cache()  # remove the timing cache just for the engine-caching measurement
        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            enabled_precisions=enabled_precisions,
            debug=debug,
            min_block_size=min_block_size,
            make_refitable=True,
            cache_built_engines=cache_built_engines,
            reuse_cached_engines=reuse_cached_engines,
            engine_cache_size=1 << 30,  # 1 GB
        )
        # output = trt_gm(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------dynamo_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


# Custom Engine Cache
class MyEngineCache(BaseEngineCache):
    def __init__(
        self,
        engine_cache_dir: str,
    ) -> None:
        self.engine_cache_dir = engine_cache_dir

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ) -> None:
        os.makedirs(self.engine_cache_dir, exist_ok=True)
        path = os.path.join(
            self.engine_cache_dir,
            f"{prefix}_{hash}.bin",
        )
        with open(path, "wb") as f:
            f.write(blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
        if os.path.exists(path):
            with open(path, "rb") as f:
                return f.read()
        return None
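# A standalone sanity check for MyEngineCache (a sketch, assuming /tmp is
# writable): save() persists a blob under the given hash, load() returns it
# byte-for-byte, and a cache miss returns None.
#
#   cache = MyEngineCache("/tmp/your_dir")
#   cache.save("deadbeef", b"\x00\x01\x02")
#   assert cache.load("deadbeef") == b"\x00\x01\x02"
#   assert cache.load("no_such_hash") is None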


def torch_compile(iterations=3):
    times = []
    engine_cache = MyEngineCache("/tmp/your_dir")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration also has to compile and save the engine, it will be slower than the 1st.
    # The 3rd iteration should be faster than the 1st because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # Remove the timing cache and reset dynamo just for the engine-caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refitable": True,
                "cache_built_engines": cache_built_engines,
                "reuse_cached_engines": reuse_cached_engines,
                "custom_engine_cache": engine_cache,  # use the custom engine cache
            },
        )
        compiled_model(*inputs)  # trigger compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------torch_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


if __name__ == "__main__":
    dynamo_compile()
    torch_compile()
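To benchmark from a cold start again, clear the caches between runs. A short sketch, assuming the defaults from _defaults.py below (the timing cache lives under the same directory as the engine cache):

import shutil

from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR

# Remove the default disk cache (engines + timing cache) and the custom cache dir
shutil.rmtree(ENGINE_CACHE_DIR, ignore_errors=True)
shutil.rmtree("/tmp/your_dir", ignore_errors=True)  # used by MyEngineCache above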
31 changes: 30 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -18,6 +18,7 @@
    dryrun_stats_display,
    parse_non_trt_nodes,
)
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache
from torch_tensorrt.dynamo.conversion import (
    CompilationSettings,
    UnsupportedOperatorException,
@@ -82,6 +83,11 @@ def compile(
    hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
    timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
    lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
    cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
    reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
    engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
    engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
    custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
    **kwargs: Any,
) -> torch.fx.GraphModule:
    """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -147,6 +153,11 @@ def compile(
        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
        timing_cache_path (str): Path to the timing cache if it exists, or where it will be saved after compilation
        lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile, but can lead to oversubscription of GPU memory at runtime.
        cache_built_engines (bool): Whether to save the compiled TRT engines to storage
        reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
        engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
        engine_cache_size (Optional[int]): Maximum hard-disk space (in bytes) to use for the engine cache; defaults to 1 GB. If the cache exceeds this size, the oldest engines are removed by default
        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
        **kwargs: Any,
    Returns:
        torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -224,6 +235,17 @@ def compile(
    gm = post_lowering(gm)
    logger.debug("Lowered Input graph: " + str(gm.graph))

    engine_cache = None
    if cache_built_engines or reuse_cached_engines:
        assert (
            make_refitable
        ), "Engine caching requires make_refitable to be set to True"
        engine_cache = (
            custom_engine_cache
            if custom_engine_cache is not None
            else DiskEngineCache(engine_cache_dir, engine_cache_size)
        )

    compilation_options = {
        "enabled_precisions": (
            enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
@@ -257,11 +279,15 @@
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
}

    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
    trt_gm = compile_module(
        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
    )
    return trt_gm


@@ -270,6 +296,7 @@ def compile_module(
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
    settings: CompilationSettings = CompilationSettings(),
    engine_cache: Optional[BaseEngineCache] = None,
) -> torch.fx.GraphModule:
    """Compile a traced FX module

@@ -280,6 +307,7 @@
        arg_inputs: Inputs to the module
        kwarg_inputs: kwargs to the module
        settings: Compilation settings
        engine_cache: Engine cache instance to store/load compiled engines
    Returns:
        Compiled FX GraphModule
    """
@@ -436,6 +464,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
            submodule_inputs,
            settings=settings,
            name=name,
            engine_cache=engine_cache,
        )

        trt_modules[name] = trt_module
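Putting the new arguments together, a hedged usage sketch (exp_program and inputs as in the examples above; with a custom cache, engine_cache_dir and engine_cache_size are ignored):

import torch_tensorrt as torch_trt

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    make_refitable=True,  # engine caching asserts this is True
    cache_built_engines=True,
    reuse_cached_engines=True,
    custom_engine_cache=MyEngineCache("/tmp/your_dir"),  # from the example above
)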
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -31,8 +31,15 @@
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
TIMING_CACHE_PATH = os.path.join(
    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
LAZY_ENGINE_INIT = False
CACHE_BUILT_ENGINES = True
REUSE_CACHED_ENGINES = True
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824  # 1 GB (1 << 30 bytes)
CUSTOM_ENGINE_CACHE = None


def default_device() -> Device: