feat: engine caching
revert backend changes

update dynamo path

add save_engine_cache and load_engine_cache args

support customizing engine cache class

refactor and add LRU to clear cache

fix bug
zewenli98 committed Aug 7, 2024
1 parent 19f671d commit fbfc863
Showing 7 changed files with 493 additions and 2 deletions.
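In short, this commit teaches Torch-TensorRT to persist compiled TRT engines to disk and reuse them on later compilations. As a condensed sketch of the usage pattern (not part of the diff — the full version is engine_caching_example.py below, and the input shape and cache directory here are illustrative):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = (torch.randn((1, 3, 224, 224), device="cuda"),)
exp_program = torch.export.export(model, args=inputs)

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    make_refitable=True,  # the example below enables refitting alongside caching
    save_engine_cache=True,  # write compiled TRT engines to disk
    load_engine_cache=True,  # reuse a cached engine when its hash key matches
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # illustrative path
    engine_cache_size=1 << 30,  # cap the on-disk cache at 1 GiB
)
trt_gm(*inputs)

The first compilation populates the cache; with load_engine_cache=True, later compilations of the same graph can skip engine building.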
174 changes: 174 additions & 0 deletions examples/dynamo/engine_caching_example.py
@@ -0,0 +1,174 @@
import ast
import logging
import os
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

_LOGGER: logging.Logger = logging.getLogger(__name__)


np.random.seed(0)
torch.manual_seed(0)
size = (100, 3, 224, 224)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False


def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)


def dynamo_path(iterations=3):
times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
# Mark the dim0 of inputs as dynamic
batch = torch.export.Dim("batch", min=1, max=200)
exp_program = torch.export.export(
model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
)

for i in range(iterations):
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
remove_timing_cache()  # remove timing cache for engine caching measurement
if i == 0:
save_engine_cache = False
load_engine_cache = False
else:
save_engine_cache = True
load_engine_cache = True

start.record()
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
save_engine_cache=save_engine_cache,
load_engine_cache=load_engine_cache,
engine_cache_size=1 << 30, # 1GB
)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("-----dynamo_path-----> compilation time:", times, "milliseconds")


# Custom Engine Cache
class MyEngineCache(BaseEngineCache):

def __init__(
self,
engine_cache_size: int,
engine_cache_dir: str,
) -> None:
self.total_engine_cache_size = engine_cache_size
self.available_engine_cache_size = engine_cache_size
self.engine_cache_dir = engine_cache_dir

def save(
self,
hash: str,
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
) -> bool:
path = os.path.join(
self.engine_cache_dir,
f"{hash}/engine--{input_names}--{output_names}.trt",
)
try:
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(serialized_engine)
except Exception as e:
_LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
return False

_LOGGER.info(f"A TRT engine was cached to {path}")
serialized_engine_size = len(serialized_engine)  # size of the serialized engine in bytes
self.available_engine_cache_size -= serialized_engine_size
return True

def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
directory = os.path.join(self.engine_cache_dir, hash)
if os.path.exists(directory):
engine_list = os.listdir(directory)
assert (
len(engine_list) == 1
), f"There are more than one engine {engine_list} under {directory}."
path = os.path.join(directory, engine_list[0])
input_names_str, output_names_str = (
engine_list[0].split(".trt")[0].split("--")[1:]
)
input_names = ast.literal_eval(input_names_str)
output_names = ast.literal_eval(output_names_str)
with open(path, "rb") as f:
serialized_engine = f.read()
return serialized_engine, input_names, output_names
else:
return None, [], []


def compile_path(iterations=3):
times = []
engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

for i in range(iterations):
inputs = [torch.rand(size).to("cuda")]
# remove timing cache and reset dynamo for engine caching measurement
remove_timing_cache()
torch._dynamo.reset()

if i == 0:
save_engine_cache = False
load_engine_cache = False
else:
save_engine_cache = True
load_engine_cache = True

start.record()
compiled_model = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"save_engine_cache": save_engine_cache,
"load_engine_cache": load_engine_cache,
"engine_cache_instance": engine_cache, # use custom engine cache
},
)
compiled_model(*inputs) # trigger the compilation
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("-----compile_path-----> compilation time:", times, "milliseconds")


if __name__ == "__main__":
dynamo_path()
compile_path()
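The committed example only wires MyEngineCache through the torch.compile backend options above. As a complementary, hypothetical sketch (not part of this commit), the same custom cache could also be passed to the dynamo path via the engine_cache_instance argument that _compiler.py gains below; exp_program and inputs are reused from dynamo_path():

# Hypothetical variant: hand the custom BaseEngineCache subclass to dynamo.compile().
custom_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    save_engine_cache=True,
    load_engine_cache=True,
    engine_cache_instance=custom_cache,  # overrides the default on-disk EngineCache
)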
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -18,6 +18,7 @@
dryrun_stats_display,
parse_non_trt_nodes,
)
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
UnsupportedOperatorException,
@@ -83,6 +84,11 @@ def compile(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE,
load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
engine_cache_instance: Optional[BaseEngineCache] = None,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -148,6 +154,11 @@ def compile(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
engine_cache_dir (str): Directory to store the cached TRT engines
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -224,6 +235,11 @@ def compile(
gm = post_lowering(gm)
logger.debug("Lowered Input graph: " + str(gm.graph))

if engine_cache_instance is None:
engine_cache_instance = EngineCacheInstanceCreator.get_creator(
engine_cache_size, engine_cache_dir
).engine_cache_instance

compilation_options = {
"enabled_precisions": (
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
@@ -257,6 +273,11 @@
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"lazy_engine_init": lazy_engine_init,
"save_engine_cache": save_engine_cache,
"load_engine_cache": load_engine_cache,
"engine_cache_dir": engine_cache_dir,
"engine_cache_size": engine_cache_size,
"engine_cache_instance": engine_cache_instance,
}

settings = CompilationSettings(**compilation_options)
@@ -703,3 +724,21 @@ def convert_exported_program_to_serialized_trt_engine(

serialized_engine: bytes = interpreter_result.serialized_engine
return serialized_engine


class EngineCacheInstanceCreator:
engine_cache_creator = None

def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None:
self.engine_cache_instance = EngineCache(
engine_cache_size=engine_cache_size,
engine_cache_dir=engine_cache_dir,
)

@classmethod
def get_creator(
cls, engine_cache_size: int, engine_cache_dir: str
) -> EngineCacheInstanceCreator:
if cls.engine_cache_creator is None:
cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir)
return cls.engine_cache_creator
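A small illustrative check (not part of the commit) of how this creator behaves: get_creator() memoizes the first creator at class level, so every compile() call that omits engine_cache_instance shares one default EngineCache, and the size/dir arguments of later calls are ignored:

creator_a = EngineCacheInstanceCreator.get_creator(1 << 30, "/tmp/torch_tensorrt_engine_cache")
creator_b = EngineCacheInstanceCreator.get_creator(1 << 20, "/some/other/dir")
assert creator_a is creator_b  # the second call returns the memoized creator
assert creator_a.engine_cache_instance is creator_b.engine_cache_instance

Sharing a single instance keeps the cache's size accounting (the LRU bookkeeping mentioned in the commit message) consistent across compilations.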
12 changes: 11 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -4,6 +4,7 @@
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._engine_caching import EngineCache

ENABLED_PRECISIONS = {dtype.f32}
DEBUG = False
@@ -31,8 +32,17 @@
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
LAZY_ENGINE_INIT = False
SAVE_ENGINE_CACHE = True
LOAD_ENGINE_CACHE = True
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824  # 1 GiB (1 << 30)
ENGINE_CACHE_INSTANCE = EngineCache(
engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR
)


def default_device() -> Device: