Commit 748b4c6

refactor

zewenli98 committed Aug 21, 2024
1 parent 59ba4a2 commit 748b4c6
Showing 8 changed files with 227 additions and 278 deletions.
13 changes: 7 additions & 6 deletions examples/dynamo/engine_caching_bert_example.py
@@ -29,11 +29,11 @@ def compile_bert(iterations=3):
         torch._dynamo.reset()

         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         compilation_kwargs = {
@@ -43,8 +43,9 @@ def compile_bert(iterations=3):
             "debug": False,
             "min_block_size": 1,
             "make_refitable": True,
-            "save_engine_cache": save_engine_cache,
-            "load_engine_cache": load_engine_cache,
+            "cache_built_engines": cache_built_engines,
+            "reuse_cached_engines": reuse_cached_engines,
+            "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
             "engine_cache_size": 1 << 30,  # 1GB
         }
         optimized_model = torch.compile(
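
Note: after this rename, a call into the example's compile path would look roughly as follows. This is a minimal sketch assuming the "torch_tensorrt" backend registered for torch.compile; the model and inputs are illustrative stand-ins (the real example uses a BERT model), not part of this commit.

import torch

# Illustrative module/inputs; the real example compiles BERT on CUDA.
model = torch.nn.Linear(8, 8).eval().cuda()
inputs = [torch.randn(1, 8, device="cuda")]

optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    dynamic=False,
    options={
        "enabled_precisions": {torch.float},
        "debug": False,
        "min_block_size": 1,
        "make_refitable": True,
        "cache_built_engines": True,   # renamed from save_engine_cache
        "reuse_cached_engines": True,  # renamed from load_engine_cache
        "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
        "engine_cache_size": 1 << 30,  # 1GB cap for the disk cache
    },
)
optimized_model(*inputs)  # first call builds (and caches) the TRT engines
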
89 changes: 30 additions & 59 deletions examples/dynamo/engine_caching_example.py
@@ -1,7 +1,5 @@
-import ast
-import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Optional

 import numpy as np
 import torch
@@ -10,9 +8,6 @@
 from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
 from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
-
 np.random.seed(0)
 torch.manual_seed(0)
 size = (100, 3, 224, 224)
@@ -49,11 +44,11 @@ def dynamo_path(iterations=3):
         inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
         remove_timing_cache()  # remove timing cache for engine caching measurement
         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         trt_gm = torch_trt.dynamo.compile(
@@ -64,8 +59,8 @@ def dynamo_path(iterations=3):
             debug=debug,
             min_block_size=min_block_size,
             make_refitable=True,
-            save_engine_cache=save_engine_cache,
-            load_engine_cache=load_engine_cache,
+            cache_built_engines=cache_built_engines,
+            reuse_cached_engines=reuse_cached_engines,
             engine_cache_size=1 << 30,  # 1GB
         )
         end.record()
@@ -79,60 +74,36 @@ def dynamo_path(iterations=3):
 class MyEngineCache(BaseEngineCache):
     def __init__(
         self,
-        engine_cache_size: int,
         engine_cache_dir: str,
     ) -> None:
-        self.total_engine_cache_size = engine_cache_size
-        self.available_engine_cache_size = engine_cache_size
         self.engine_cache_dir = engine_cache_dir

     def save(
         self,
         hash: str,
-        serialized_engine: bytes,
-        input_names: List[str],
-        output_names: List[str],
-    ) -> bool:
+        blob: bytes,
+        prefix: str = "blob",
+    ):
         path = os.path.join(
             self.engine_cache_dir,
-            f"{hash}/engine--{input_names}--{output_names}.trt",
+            f"{prefix}_{hash}.bin",
         )
-        try:
-            os.makedirs(os.path.dirname(path), exist_ok=True)
-            with open(path, "wb") as f:
-                f.write(serialized_engine)
-        except Exception as e:
-            _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
-            return False
-
-        _LOGGER.info(f"A TRT engine was cached to {path}")
-        serialized_engine_size = int(serialized_engine.nbytes)
-        self.available_engine_cache_size -= serialized_engine_size
-        return True
-
-    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
-        directory = os.path.join(self.engine_cache_dir, hash)
-        if os.path.exists(directory):
-            engine_list = os.listdir(directory)
-            assert (
-                len(engine_list) == 1
-            ), f"There are more than one engine {engine_list} under {directory}."
-            path = os.path.join(directory, engine_list[0])
-            input_names_str, output_names_str = (
-                engine_list[0].split(".trt")[0].split("--")[1:]
-            )
-            input_names = ast.literal_eval(input_names_str)
-            output_names = ast.literal_eval(output_names_str)
-            with open(path, "rb") as f:
-                serialized_engine = f.read()
-            return serialized_engine, input_names, output_names
-        else:
-            return None, [], []
+        os.makedirs(os.path.dirname(path), exist_ok=True)  # create the parent dir, not the file path itself
+        with open(path, "wb") as f:
+            f.write(blob)
+
+    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
+        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
+        if os.path.exists(path):
+            with open(path, "rb") as f:
+                blob = f.read()
+            return blob
+        return None


 def compile_path(iterations=3):
     times = []
-    engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
+    engine_cache = MyEngineCache("/tmp/your_dir")
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
@@ -147,11 +118,11 @@ def compile_path(iterations=3):
         torch._dynamo.reset()

         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         compiled_model = torch.compile(
@@ -163,9 +134,9 @@ def compile_path(iterations=3):
                 "debug": debug,
                 "min_block_size": min_block_size,
                 "make_refitable": True,
-                "save_engine_cache": save_engine_cache,
-                "load_engine_cache": load_engine_cache,
-                "engine_cache_instance": engine_cache,  # use custom engine cache
+                "cache_built_engines": cache_built_engines,
+                "reuse_cached_engines": reuse_cached_engines,
+                "custom_engine_cache": engine_cache,  # use custom engine cache
             },
         )
         compiled_model(*inputs)  # trigger the compilation
@@ -178,4 +149,4 @@ def compile_path(iterations=3):

 if __name__ == "__main__":
     dynamo_path()
-    compile_path()
+    # compile_path()
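
Note: the reworked BaseEngineCache contract above reduces save/load to opaque blobs keyed by hash. A minimal sketch of an alternative subclass under that assumed contract (the class and its dict backing are hypothetical, for illustration only):

from typing import Dict, Optional

from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


class InMemoryEngineCache(BaseEngineCache):
    """Hypothetical cache that keeps serialized engine blobs in memory."""

    def __init__(self) -> None:
        self.blobs: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes, prefix: str = "blob") -> None:
        # Mirror the "{prefix}_{hash}" keying used by MyEngineCache above.
        self.blobs[f"{prefix}_{hash}"] = blob

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        return self.blobs.get(f"{prefix}_{hash}")

Such an instance would be passed through the new custom_engine_cache option exactly as MyEngineCache is in compile_path above.
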
47 changes: 13 additions & 34 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -18,7 +18,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -84,11 +84,11 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
-    save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE,
-    load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
+    cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
+    reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
-    engine_cache_instance: Optional[BaseEngineCache] = None,
+    custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -154,11 +154,11 @@ def compile(
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
-        save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
-        load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
+        cache_built_engines (bool): Whether to save the compiled TRT engines to storage
+        reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
         engine_cache_dir (str): Directory to store the cached TRT engines
         engine_cache_size (int): Maximum hard-disk space to use for the engine cache
-        engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
+        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -235,10 +235,9 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))

-    if engine_cache_instance is None:
-        engine_cache_instance = EngineCacheInstanceCreator.get_creator(
-            engine_cache_size, engine_cache_dir
-        ).engine_cache_instance
+    if cache_built_engines or reuse_cached_engines:
+        if custom_engine_cache is None:
+            custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)

     compilation_options = {
         "enabled_precisions": (
@@ -273,11 +272,9 @@ def compile(
         "hardware_compatible": hardware_compatible,
         "timing_cache_path": timing_cache_path,
         "lazy_engine_init": lazy_engine_init,
-        "save_engine_cache": save_engine_cache,
-        "load_engine_cache": load_engine_cache,
-        "engine_cache_dir": engine_cache_dir,
-        "engine_cache_size": engine_cache_size,
-        "engine_cache_instance": engine_cache_instance,
+        "cache_built_engines": cache_built_engines,
+        "reuse_cached_engines": reuse_cached_engines,
+        "custom_engine_cache": custom_engine_cache,
     }

     settings = CompilationSettings(**compilation_options)
@@ -724,21 +721,3 @@ def convert_exported_program_to_serialized_trt_engine(

     serialized_engine: bytes = interpreter_result.serialized_engine
     return serialized_engine
-
-
-class EngineCacheInstanceCreator:
-    engine_cache_creator = None
-
-    def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None:
-        self.engine_cache_instance = EngineCache(
-            engine_cache_size=engine_cache_size,
-            engine_cache_dir=engine_cache_dir,
-        )
-
-    @classmethod
-    def get_creator(
-        cls, engine_cache_size: int, engine_cache_dir: str
-    ) -> EngineCacheInstanceCreator:
-        if cls.engine_cache_creator is None:
-            cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir)
-        return cls.engine_cache_creator
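
For context, the net effect of the _compiler.py changes on a direct dynamo.compile call is sketched below. Illustrative model and inputs only; under the new logic a DiskEngineCache is constructed internally only when one of the two caching flags is set and no custom cache is given.

import torch
import torch_tensorrt as torch_trt

model = torch.nn.Linear(8, 8).eval().cuda()  # illustrative stand-in
inputs = [torch.randn(1, 8, device="cuda")]

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    cache_built_engines=True,    # save newly built engines
    reuse_cached_engines=True,   # load matching cached engines
    engine_cache_dir="/tmp/torch_trt_engine_cache",
    engine_cache_size=1 << 30,
    # custom_engine_cache=InMemoryEngineCache(),  # if set, the dir/size
    #                                             # settings above are ignored
)
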
9 changes: 3 additions & 6 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -4,7 +4,6 @@
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
-from torch_tensorrt.dynamo._engine_caching import EngineCache

 ENABLED_PRECISIONS = {dtype.f32}
 DEBUG = False
@@ -36,13 +35,11 @@
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
 LAZY_ENGINE_INIT = False
-SAVE_ENGINE_CACHE = True
-LOAD_ENGINE_CACHE = True
+CACHE_BUILT_ENGINES = True
+REUSE_CACHED_ENGINES = True
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
-ENGINE_CACHE_INSTANCE = EngineCache(
-    engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR
-)
+CUSTOM_ENGINE_CACHE = None


 def default_device() -> Device:
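
With CUSTOM_ENGINE_CACHE defaulting to None, the eager EngineCache construction at import time is gone; the compiler now resolves the cache lazily. A sketch restating that resolution (names as in the _compiler.py hunk above; ENGINE_CACHE_SIZE = 1073741824 is 1 << 30 bytes, i.e. 1 GiB):

# Nothing is created on module import; a DiskEngineCache appears only
# when engine caching is actually requested and no custom cache is given.
if cache_built_engines or reuse_cached_engines:
    if custom_engine_cache is None:
        custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
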