From 95cc532aa6a11c8bb55bc43bc2b04a76d264577e Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Thu, 29 Aug 2024 16:35:29 -0700 Subject: [PATCH] feat: engine caching (#2995) --- .../dynamo/engine_caching_bert_example.py | 65 ++++ examples/dynamo/engine_caching_example.py | 160 +++++++++ py/torch_tensorrt/dynamo/_compiler.py | 31 +- py/torch_tensorrt/dynamo/_defaults.py | 9 +- py/torch_tensorrt/dynamo/_engine_caching.py | 251 ++++++++++++++ py/torch_tensorrt/dynamo/_settings.py | 6 + py/torch_tensorrt/dynamo/backend/backends.py | 7 +- .../dynamo/conversion/_TRTInterpreter.py | 80 ++++- .../dynamo/conversion/_conversion.py | 37 +-- py/torch_tensorrt/dynamo/utils.py | 30 +- .../conversion/test_bitwise_and_aten.py | 7 +- .../conversion/test_embedding_bag_aten.py | 7 +- .../conversion/test_index_select_aten.py | 7 +- tests/py/dynamo/models/test_dtype_support.py | 14 + tests/py/dynamo/models/test_dyn_models.py | 14 + tests/py/dynamo/models/test_engine_cache.py | 313 ++++++++++++++++++ .../dynamo/models/test_export_kwargs_serde.py | 14 + tests/py/dynamo/models/test_export_serde.py | 23 +- tests/py/dynamo/models/test_models.py | 10 + tests/py/dynamo/models/test_models_export.py | 14 + tests/py/dynamo/runtime/test_001_streams.py | 2 + .../runtime/test_002_lazy_engine_init.py | 10 + 22 files changed, 1069 insertions(+), 42 deletions(-) create mode 100644 examples/dynamo/engine_caching_bert_example.py create mode 100644 examples/dynamo/engine_caching_example.py create mode 100644 py/torch_tensorrt/dynamo/_engine_caching.py create mode 100644 tests/py/dynamo/models/test_engine_cache.py diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py new file mode 100644 index 0000000000..43cfc5f15a --- /dev/null +++ b/examples/dynamo/engine_caching_bert_example.py @@ -0,0 +1,65 @@ +import numpy as np +import torch +import torch_tensorrt +from engine_caching_example import remove_timing_cache +from transformers import BertModel + +np.random.seed(0) +torch.manual_seed(0) + +model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval() +inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +] + + +def compile_bert(iterations=3): + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
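+    # The cache key is a hash of the traced GraphModule with its weights zeroed out
+    # (see BaseEngineCache.get_hash), so a cached engine can be reused across runs
+    # and is refit with the current weights when it is loaded from the cache.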
+ for i in range(iterations): + # remove timing cache and reset dynamo for engine caching messurement + remove_timing_cache() + torch._dynamo.reset() + + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + compilation_kwargs = { + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "truncate_double": True, + "debug": False, + "min_block_size": 1, + "make_refitable": True, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", + "engine_cache_size": 1 << 30, # 1GB + } + optimized_model = torch.compile( + model, + backend="torch_tensorrt", + options=compilation_kwargs, + ) + optimized_model(*inputs) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + print("-----compile bert-----> compilation time:\n", times, "milliseconds") + + +if __name__ == "__main__": + compile_bert() diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py new file mode 100644 index 0000000000..2d1018bb6e --- /dev/null +++ b/examples/dynamo/engine_caching_example.py @@ -0,0 +1,160 @@ +import os +from typing import Optional + +import numpy as np +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache + +np.random.seed(0) +torch.manual_seed(0) + +model = models.resnet18(pretrained=True).eval().to("cuda") +enabled_precisions = {torch.float} +debug = False +min_block_size = 1 +use_python_runtime = False + + +def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + +def dynamo_compile(iterations=3): + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
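+    # The TensorRT timing cache is deleted before every iteration so that the speedup
+    # observed in the later iterations can be attributed to the engine cache rather
+    # than to cached layer-timing (tactic) data.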
+ for i in range(iterations): + inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")] + remove_timing_cache() # remove timing cache just for engine caching messurement + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + make_refitable=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_size=1 << 30, # 1GB + ) + # output = trt_gm(*inputs) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + print("----------------dynamo_compile----------------") + print("disable engine caching, used:", times[0], "ms") + print("enable engine caching to cache engines, used:", times[1], "ms") + print("enable engine caching to reuse engines, used:", times[2], "ms") + + +# Custom Engine Cache +class MyEngineCache(BaseEngineCache): + def __init__( + self, + engine_cache_dir: str, + ) -> None: + self.engine_cache_dir = engine_cache_dir + + def save( + self, + hash: str, + blob: bytes, + prefix: str = "blob", + ): + if not os.path.exists(self.engine_cache_dir): + os.makedirs(self.engine_cache_dir, exist_ok=True) + + path = os.path.join( + self.engine_cache_dir, + f"{prefix}_{hash}.bin", + ) + with open(path, "wb") as f: + f.write(blob) + + def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]: + path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin") + if os.path.exists(path): + with open(path, "rb") as f: + blob = f.read() + return blob + return None + + +def torch_compile(iterations=3): + times = [] + engine_cache = MyEngineCache("/tmp/your_dir") + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
+ for i in range(iterations): + inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] + # remove timing cache and reset dynamo just for engine caching messurement + remove_timing_cache() + torch._dynamo.reset() + + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + compiled_model = torch.compile( + model, + backend="tensorrt", + options={ + "use_python_runtime": True, + "enabled_precisions": enabled_precisions, + "debug": debug, + "min_block_size": min_block_size, + "make_refitable": True, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "custom_engine_cache": engine_cache, # use custom engine cache + }, + ) + compiled_model(*inputs) # trigger the compilation + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + print("----------------torch_compile----------------") + print("disable engine caching, used:", times[0], "ms") + print("enable engine caching to cache engines, used:", times[1], "ms") + print("enable engine caching to reuse engines, used:", times[2], "ms") + + +if __name__ == "__main__": + dynamo_compile() + torch_compile() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a4849f257e..c28702f451 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -18,6 +18,7 @@ dryrun_stats_display, parse_non_trt_nodes, ) +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache from torch_tensorrt.dynamo.conversion import ( CompilationSettings, UnsupportedOperatorException, @@ -82,6 +83,11 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, + cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, + reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, + engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR, + engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE, + custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -147,6 +153,11 @@ def compile( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. + cache_built_engines (bool): Whether to save the compiled TRT engines to storage + reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage + engine_cache_dir (Optional[str]): Directory to store the cached TRT engines + engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default + custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. 
If used, engine_cache_dir and engine_cache_size will be ignored. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -224,6 +235,17 @@ def compile( gm = post_lowering(gm) logger.debug("Lowered Input graph: " + str(gm.graph)) + engine_cache = None + if cache_built_engines or reuse_cached_engines: + assert ( + make_refitable + ), "Engine caching requires make_refitable to be set to True" + engine_cache = ( + custom_engine_cache + if custom_engine_cache is not None + else DiskEngineCache(engine_cache_dir, engine_cache_size) + ) + compilation_options = { "enabled_precisions": ( enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS @@ -257,11 +279,15 @@ def compile( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "lazy_engine_init": lazy_engine_init, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, } settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings) + trt_gm = compile_module( + gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache + ) return trt_gm @@ -270,6 +296,7 @@ def compile_module( sample_arg_inputs: Sequence[Input], sample_kwarg_inputs: Optional[dict[Any, Any]] = None, settings: CompilationSettings = CompilationSettings(), + engine_cache: Optional[BaseEngineCache] = None, ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -280,6 +307,7 @@ def compile_module( arg_inputs: Inputs to the module kwarg_inputs: kwargs to the module settings: Compilation settings + engine_cache: Engine cache instance to store/load compiled engines Returns: Compiled FX GraphModule """ @@ -436,6 +464,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: submodule_inputs, settings=settings, name=name, + engine_cache=engine_cache, ) trt_modules[name] = trt_module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 2696e26936..83e85cb3c7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -31,8 +31,15 @@ DRYRUN = False HARDWARE_COMPATIBLE = False SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} -TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin") +TIMING_CACHE_PATH = os.path.join( + tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" +) LAZY_ENGINE_INIT = False +CACHE_BUILT_ENGINES = True +REUSE_CACHED_ENGINES = True +ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") +ENGINE_CACHE_SIZE = 1073741824 +CUSTOM_ENGINE_CACHE = None def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_engine_caching.py b/py/torch_tensorrt/dynamo/_engine_caching.py new file mode 100644 index 0000000000..c8ff7aba50 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_engine_caching.py @@ -0,0 +1,251 @@ +import copy +import logging +import os +import pickle +import shutil +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, cast + +import torch +from torch._inductor.codecache import FxGraphCachePickler +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class BaseEngineCache(ABC): + + @abstractmethod + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + pass + + @staticmethod + def 
get_hash(gm: torch.fx.GraphModule) -> str: + """Get the hash value of the GraphModule + + Args: + gm (torch.fx.GraphModule): GraphModule to hash + + Returns: + str: hash value of the GraphModule + """ + # parameters are set to 0 + with unset_fake_temporarily(): + new_gm = copy.deepcopy(gm) + for name, param in new_gm.named_parameters(): + param.data.zero_() + + hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm)) + + return hash_val + + @staticmethod + def pack( + serialized_engine: bytes, + input_names: List[str], + output_names: List[str], + weight_name_map: Optional[Dict[Any, Any]], + ) -> bytes: + """Pack serialized engine, input names, output names, and weight map into a single blob + + Args: + serialized_engine (bytes): serialized TRT engine + input_names (List[str]): input names of TRT engine + output_names (List[str]): output names of TRT engine + weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting + + Returns: + bytes: packed blob + """ + return pickle.dumps( + { + "serialized_engine": bytes(serialized_engine), + "input_names": input_names, + "output_names": output_names, + "weight_name_map": weight_name_map, + } + ) + + @staticmethod + def unpack( + packed_obj: bytes, + ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]: + """Unpack packed blob into serialized engine, input names, output names, and weight map + + Args: + packed_obj (bytes): packed blob + + Returns: + Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map + """ + unpacked = pickle.loads(packed_obj) + return ( + unpacked["serialized_engine"], + unpacked["input_names"], + unpacked["output_names"], + unpacked["weight_name_map"], + ) + + @abstractmethod + def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None: + """Store blob in cache + + Args: + hash (str): hash value of the GraphModule + blob (bytes): packed blob + """ + pass + + @abstractmethod + def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]: + """Load blob from storage + + Args: + hash (str): hash value of the GraphModule + + Returns: + Optional[bytes]: blob or None if doesn't hit + """ + pass + + +class DiskEngineCache(BaseEngineCache): + dir2hash2size_map: Dict[str, Dict[str, int]] = ( + {} + ) # dir2hash2size_map["engine_cache_dir"]["hash"] = size + + def __init__( + self, + engine_cache_dir: str, + engine_cache_size: int, + ) -> None: + + def get_dir_size(path: str) -> int: + total = 0 + with os.scandir(path) as it: + for entry in it: + if entry.is_file(): + total += entry.stat().st_size + elif entry.is_dir(): + total += get_dir_size(entry.path) + return total + + if not os.path.exists(engine_cache_dir): + os.makedirs(engine_cache_dir, exist_ok=True) + self.engine_cache_dir = engine_cache_dir + self.total_engine_cache_size = engine_cache_size + self.available_engine_cache_size = engine_cache_size - get_dir_size( + engine_cache_dir + ) + if engine_cache_dir not in DiskEngineCache.dir2hash2size_map: + DiskEngineCache.dir2hash2size_map[engine_cache_dir] = {} + + def has_available_cache_size(self, needed_size: int) -> bool: + """Check if the cache has available space for saving object + + Args: + needed_size (int): needed size for saving object + + Returns: + bool: whether the cache has available size for saving object + """ + return needed_size <= self.available_engine_cache_size + + def clear_cache(self, needed_min_size: int) -> None: + """Clear the cache to make sure at least `needed_min_size` bytes are 
available, if possible + + Args: + needed_min_size (int): the minimum needed size + """ + + def LRU() -> None: + """Clear the Least Recently Used engine in the cache""" + # Get the list of engine directories + engines_hash_values = os.listdir(self.engine_cache_dir) + # Sort the engine directories by modification time (oldest first) + engines_hash_values.sort( + key=lambda x: os.path.getmtime(os.path.join(self.engine_cache_dir, x)) + ) + # Iterate over the engine directories and remove the oldest ones until enough space is available + for engine_hash in engines_hash_values: + if self.available_engine_cache_size >= needed_min_size: + break + engine_path = os.path.join(self.engine_cache_dir, engine_hash) + try: + # Remove the entire directory + shutil.rmtree(engine_path) + # Update the available cache size + self.available_engine_cache_size += ( + DiskEngineCache.dir2hash2size_map[self.engine_cache_dir].pop( + engine_hash, 0 + ) + ) + _LOGGER.info( + f"Removed the engine cache at {engine_path}, available cache size: {self.available_engine_cache_size} bytes." + ) + except Exception as e: + _LOGGER.warning( + f"Failed to clear the engine cache at {engine_path}: {e}" + ) + + if needed_min_size > self.total_engine_cache_size: + _LOGGER.warning( + f"The needed minimum size {needed_min_size} is larger than the total cache size {self.total_engine_cache_size}. Nothing will be cleared." + ) + else: + LRU() + + def save( + self, + hash: str, + blob: bytes, + ) -> None: + blob_size = len(blob) + if blob_size > self.total_engine_cache_size: + _LOGGER.warning( + f"The serialized engine cannot be saved because the size {blob_size} is larger than the total cache size {self.total_engine_cache_size}." + ) + return + + if not self.has_available_cache_size(blob_size): + self.clear_cache(blob_size) + + if self.has_available_cache_size(blob_size): + DiskEngineCache.dir2hash2size_map[self.engine_cache_dir][hash] = blob_size + self.available_engine_cache_size -= blob_size + directory = os.path.join(self.engine_cache_dir, hash) + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + + blob_path = os.path.join( + directory, + "blob.bin", + ) + try: + with open(blob_path, "wb") as f: + f.write(blob) + _LOGGER.info(f"The blob was saved to {blob_path}") + except Exception as e: + del DiskEngineCache.dir2hash2size_map[self.engine_cache_dir][hash] + self.available_engine_cache_size += blob_size + shutil.rmtree(directory) + _LOGGER.warning(f"Failed to save the blob to {blob_path}: {e}") + + else: + _LOGGER.warning( + f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}." 
+ ) + + def load(self, hash: str) -> Optional[bytes]: + directory = os.path.join(self.engine_cache_dir, hash) + if os.path.exists(directory): + blob_path = os.path.join(directory, "blob.bin") + if os.path.exists(blob_path): + with open(blob_path, "rb") as f: + blob = f.read() + return blob + return None diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 4a9792d3dc..063f6f3718 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -6,6 +6,7 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + CACHE_BUILT_ENGINES, DEBUG, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -24,6 +25,7 @@ OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, REQUIRE_FULL_COMPILATION, + REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, @@ -74,6 +76,8 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + cache_built_engines (bool): Whether to save the compiled TRT engines to storage + reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -106,3 +110,5 @@ class CompilationSettings: hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH lazy_engine_init: bool = LAZY_ENGINE_INIT + cache_built_engines: bool = CACHE_BUILT_ENGINES + reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ae3cb38f2d..605d963a50 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -48,14 +48,15 @@ def torch_tensorrt_backend( def aot_torch_tensorrt_aten_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any ) -> torch.nn.Module: - settings = parse_dynamo_kwargs(kwargs) - return _pretraced_backend(gm, sample_inputs, settings) + settings, engine_cache = parse_dynamo_kwargs(kwargs) + return _pretraced_backend(gm, sample_inputs, settings, engine_cache) def _pretraced_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], settings: CompilationSettings = CompilationSettings(), + engine_cache: Any = None, ) -> torch.fx.GraphModule | Callable[..., Any]: """Helper function to manage translation of traced FX module to TRT engines @@ -63,6 +64,7 @@ def _pretraced_backend( module: FX GraphModule to convert inputs: Inputs to the module settings: Compilation settings + engine_cache: Engine cache instance Returns: Compiled FX GraphModule """ @@ -109,6 +111,7 @@ def _pretraced_backend( gm, torchtrt_inputs, settings=settings, + engine_cache=engine_cache, ) return trt_compiled except (AssertionError, RuntimeError): diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9fef61961b..3c97c8347a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -18,6 +18,7 @@ ) import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import 
_get_qualified_name @@ -26,6 +27,7 @@ from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -38,11 +40,10 @@ get_node_name, get_trt_tensor, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -71,6 +72,7 @@ def __init__( logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, output_dtypes: Optional[Sequence[dtype]] = None, compilation_settings: CompilationSettings = CompilationSettings(), + engine_cache: Optional[BaseEngineCache] = None, ): super().__init__(module) @@ -126,6 +128,9 @@ def __init__( self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {} self.weight_name_map: Optional[dict[str, Any]] = None + # Engine cache for storing and reusing TRT engines + self.engine_cache = engine_cache + def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -323,6 +328,7 @@ def _save_timing_cache( This is called after a TensorRT engine is built. Save the timing cache """ timing_cache = builder_config.get_timing_cache() + os.makedirs(os.path.dirname(timing_cache_path), exist_ok=True) with open(timing_cache_path, "wb") as timing_cache_file: timing_cache_file.write(memoryview(timing_cache.serialize())) @@ -428,9 +434,8 @@ def _save_weight_mapping(self) -> None: """ _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping - sd = self.module.state_dict() torch_device = to_torch_device(self.compilation_settings.device) - gm_is_on_cuda = list(sd.values())[0].device.type == "cuda" + gm_is_on_cuda = get_model_device(self.module).type == "cuda" if not gm_is_on_cuda: # If the model original position is on CPU, move it GPU sd = { @@ -516,15 +521,71 @@ def run( Args: strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. algorithm_selector: set up algorithm selection for certain layer + tactic_sources: set up tactic sources for certain layer Return: TRTInterpreterResult """ + # self.engine_cache could be None if: + # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or + # 2) both cache_built_engines and reuse_cached_engines are False + if self.engine_cache is not None: + if ( + self.compilation_settings.cache_built_engines + or self.compilation_settings.reuse_cached_engines + ): + hash_val = self.engine_cache.get_hash(self.module) + + if self.compilation_settings.reuse_cached_engines: + # query the cached TRT engine + blob = self.engine_cache.load(hash_val) + if blob is not None: # hit the cache + serialized_engine, input_names, output_names, weight_name_map = ( + self.engine_cache.unpack(blob) + ) + self._input_names = input_names + self._output_names = output_names + self.weight_name_map = weight_name_map + _LOGGER.info( + "Found the cached engine that corresponds to this graph. It is directly loaded." 
+ ) + + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + # TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers. + # We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future. + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=None, + ) + + serialized_engine = engine.serialize() + + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() + + return TRTInterpreterResult( + engine_str, + self._input_names, + self._output_names, + self.weight_name_map, + ) + self._construct_trt_network_def() if self.compilation_settings.make_refitable: self._save_weight_mapping() build_engine_start_time = datetime.now() + _LOGGER.info("Not found cached TRT engines. Start building engine.") builder_config = self._populate_trt_builder_config( strict_type_constraints, algorithm_selector, tactic_sources @@ -547,6 +608,17 @@ def run( self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) + if ( + self.engine_cache is not None + and self.compilation_settings.cache_built_engines + ): + blob = self.engine_cache.pack( + serialized_engine, + self._input_names, + self._output_names, + self.weight_name_map, + ) + self.engine_cache.save(hash_val, blob) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index e0643cf996..cd38ce56e6 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -10,6 +10,7 @@ from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( TRTInterpreter, @@ -76,6 +77,7 @@ def interpret_module_to_result( settings: CompilationSettings = CompilationSettings(), arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, + engine_cache: Optional[BaseEngineCache] = None, ) -> TRTInterpreterResult: """Interpret an FX module to a TRTInterpreterResult Args: @@ -85,6 +87,7 @@ def interpret_module_to_result( arg_inputs: Sequence of Tensors representing inputs to the module. kwarg_inputs: A dictionary of Tensors representing inputs to the module. 
settings: Compilation settings + engine_cache: Engine cache instance Returns: TRTInterpreterResult """ @@ -111,6 +114,7 @@ def interpret_module_to_result( logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, compilation_settings=settings, + engine_cache=engine_cache, ) interpreter_result = interpreter.run() @@ -122,6 +126,7 @@ def convert_module( inputs: Sequence[Input], settings: CompilationSettings = CompilationSettings(), name: str = "", + engine_cache: Optional[BaseEngineCache] = None, ) -> PythonTorchTensorRTModule | TorchTensorRTModule: """Convert an FX module to a TRT module Args: @@ -129,35 +134,13 @@ def convert_module( inputs: Sequence of Tensors representing inputs to the module settings: Compilation settings name: TRT engine name + engine_cache: Engine cache instance Returns: PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result = interpret_module_to_result(module, inputs, settings) - # Test fast refit: - from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm - from torch_tensorrt.logging import TRT_LOGGER - - weight_name_map: Any = None - # Do the test refit with cached map if make_refitable is enabled - if settings.make_refitable: - runtime = trt.Runtime(TRT_LOGGER) - refit_test_engine = runtime.deserialize_cuda_engine( - interpreter_result.serialized_engine - ) - try: - _refit_single_trt_engine_with_gm( - new_gm=module, - old_engine=refit_test_engine, - input_list=inputs, - settings=settings, - weight_name_map=interpreter_result.weight_name_map, - ) - weight_name_map = interpreter_result.weight_name_map - except AssertionError: - logger.warning("Fast refit test failed. Removing the weight map caching.") - - del refit_test_engine - torch.cuda.empty_cache() + interpreter_result = interpret_module_to_result( + module, inputs, settings, engine_cache=engine_cache + ) rt_cls = PythonTorchTensorRTModule @@ -181,5 +164,5 @@ def convert_module( output_binding_names=list(interpreter_result.output_names), name=name, settings=settings, - weight_name_map=weight_name_map, + weight_name_map=interpreter_result.weight_name_map, ) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index dfd22e7f9f..66192d59a0 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -3,7 +3,7 @@ import logging from dataclasses import fields, replace from enum import Enum -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np import tensorrt as trt @@ -13,6 +13,7 @@ from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings from packaging import version @@ -438,7 +439,9 @@ def to_torch_tensorrt_device( return Device._from(device) -def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: +def parse_dynamo_kwargs( + kwargs: Any, +) -> Tuple[CompilationSettings, Optional[BaseEngineCache]]: """Parses the kwargs field of a Dynamo backend Args: @@ -495,9 +498,30 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: ) settings.require_full_compilation = False + # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided, + # then create a default disk engine cache + engine_cache = None + if 
kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"): + assert kwargs.get( + "make_refitable" + ), "Engine caching requires make_refitable to be set to True" + + if kwargs.get("custom_engine_cache") is not None: + engine_cache = kwargs.get("custom_engine_cache") + else: + from torch_tensorrt.dynamo._engine_caching import DiskEngineCache + + engine_cache_dir = kwargs.get( + "engine_cache_dir", _defaults.ENGINE_CACHE_DIR + ) + engine_cache_size = kwargs.get( + "engine_cache_size", _defaults.ENGINE_CACHE_SIZE + ) + engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size) + logger.info("Compilation Settings: %s\n", settings) - return settings + return settings, engine_cache def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]: diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index a29a8061db..c42fd2e61f 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -141,7 +141,12 @@ def forward(self, lhs_val, rhs_val): mod, inputs, dynamic_shapes=({1: dyn_dim}, {0: dyn_dim}) ) trt_mod = torch_tensorrt.dynamo.compile( - fx_mod, inputs=inputs, enable_precisions={torch.bool}, min_block_size=1 + fx_mod, + inputs=inputs, + enable_precisions={torch.bool}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, ) with torch.no_grad(): cuda_inputs = [] diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index d935134ff2..3fef3d70cf 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -484,7 +484,12 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): dynamic_shapes["per_sample_weights"] = {} fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes) trt_mod = torch_tensorrt.dynamo.compile( - fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 + fx_mod, + inputs=inputs, + enable_precisions=torch.float32, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 3d0b41b791..b1339efdcf 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -109,7 +109,12 @@ def forward(self, source_tensor, indice_tensor): fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes) trt_mod = torch_tensorrt.dynamo.compile( - fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 + fx_mod, + inputs=inputs, + enable_precisions=torch.float32, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, ) # use different shape of inputs for inference: inputs = (source_tensor_1, indice_tensor) diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 29faf4eff3..b486784e52 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -41,6 +41,8 @@ def forward(self, x): truncate_double=True, min_block_size=1, use_python_runtime=False, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -79,6 +81,8 @@ def forward(self, x): 
truncate_double=True, min_block_size=1, use_python_runtime=True, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -123,6 +127,8 @@ def forward(self, x): truncate_double=False, min_block_size=1, use_python_runtime=False, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -162,6 +168,8 @@ def forward(self, x): truncate_double=False, min_block_size=1, use_python_runtime=True, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -214,6 +222,8 @@ def forward(self, x): enabled_precisions={torch.float, torch.bfloat16, torch.half}, min_block_size=1, use_python_runtime=False, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -252,6 +262,8 @@ def forward(self, x): enabled_precisions={torch.float, torch.bfloat16, torch.half}, min_block_size=1, use_python_runtime=True, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(in_tensor) @@ -289,6 +301,8 @@ def forward(self, x): debug=True, min_block_size=1, device=device, + cache_built_engines=False, + reuse_cached_engines=False, ) torch_model_results = mod(*inputs) diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 67eaddcc6c..d5627499f5 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -36,6 +36,8 @@ def forward(self, x): "ir": ir, "pass_through_build_failures": True, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } if ir == "torch_compile": input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda") @@ -90,6 +92,8 @@ def forward(self, x): "pass_through_build_failures": True, "torch_executed_ops": {"torch.ops.aten.abs.default"}, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } if ir == "torch_compile": @@ -141,6 +145,8 @@ def forward(self, x): "ir": ir, "pass_through_build_failures": True, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } if ir == "torch_compile": @@ -184,6 +190,8 @@ def test_resnet_dynamic(ir): "ir": ir, "pass_through_build_failures": True, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } if ir == "torch_compile": @@ -246,6 +254,8 @@ def forward(self, x): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -278,6 +288,8 @@ def forward(self, x): "enabled_precisions": {torch.float}, "ir": ir, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } inputs_bs2 = torch.randn(2, 2, 10).to("cuda") if ir == "torch_compile": @@ -332,6 +344,8 @@ def forward(self, x): "pass_through_build_failures": True, "min_block_size": 1, "torch_executed_ops": {"torch.ops.aten.add.Tensor"}, + "cache_built_engines": False, + "reuse_cached_engines": False, } # Compile the model diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py new file mode 100644 index 0000000000..189a492d4e --- /dev/null +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -0,0 +1,313 @@ +# type: ignore +import os +import shutil +import unittest +from typing import Optional + +import pytest +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models 
+from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +class MyEngineCache(BaseEngineCache): + def __init__( + self, + engine_cache_dir: str, + ) -> None: + self.engine_cache_dir = engine_cache_dir + if not os.path.exists(self.engine_cache_dir): + os.makedirs(self.engine_cache_dir, exist_ok=True) + + def save( + self, + hash: str, + blob: bytes, + prefix: str = "blob", + ): + if not os.path.exists(self.engine_cache_dir): + os.makedirs(self.engine_cache_dir, exist_ok=True) + + path = os.path.join( + self.engine_cache_dir, + f"{prefix}_{hash}.bin", + ) + with open(path, "wb") as f: + f.write(blob) + + def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]: + path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin") + if os.path.exists(path): + with open(path, "rb") as f: + blob = f.read() + return blob + return None + + +class TestEngineCache(TestCase): + + def test_dynamo_compile_with_default_disk_engine_cache(self): + model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + engine_cache_dir = ENGINE_CACHE_DIR + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=False, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + make_refitable=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + ) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + results.append(trt_gm(*inputs)) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. 
Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) + + def test_dynamo_compile_with_custom_engine_cache(self): + model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/your_dir" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + custom_engine_cache = MyEngineCache(engine_cache_dir) + + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=False, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + make_refitable=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + custom_engine_cache=custom_engine_cache, + ) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + results.append(trt_gm(*inputs)) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) + + def test_torch_compile_with_default_disk_engine_cache(self): + # Custom Engine Cache + model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/test_torch_compile_with_default_disk_engine_cache" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
+ inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + # remove timing cache and reset dynamo for engine caching messurement + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + compiled_model = torch.compile( + model, + backend="tensorrt", + options={ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "make_refitable": True, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "engine_cache_dir": engine_cache_dir, + "engine_cache_size": 1 << 30, # 1GB + }, + ) + results.append(compiled_model(*inputs)) # trigger the compilation + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) + + def test_torch_compile_with_custom_engine_cache(self): + # Custom Engine Cache + model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/your_dir" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + custom_engine_cache = MyEngineCache(engine_cache_dir) + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. + inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + # remove timing cache and reset dynamo for engine caching messurement + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + start.record() + compiled_model = torch.compile( + model, + backend="tensorrt", + options={ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "make_refitable": True, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "custom_engine_cache": custom_engine_cache, + }, + ) + results.append(compiled_model(*inputs)) # trigger the compilation + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py index 08b23d55e0..52a927e518 100644 --- a/tests/py/dynamo/models/test_export_kwargs_serde.py +++ b/tests/py/dynamo/models/test_export_kwargs_serde.py @@ -63,6 +63,8 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs) @@ -122,6 +124,8 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -190,6 +194,8 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -271,6 +277,8 @@ def forward(self, x, b=None, c=None, d=None, e=[]): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -358,6 +366,8 @@ def forward(self, x, b=None, c=None, d=None, e=[]): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -444,6 +454,8 @@ def forward(self, x, b=None, c=None, d=None, e=[]): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -505,6 +517,8 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index c0c0ba0f22..146cc2addf 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -42,6 +42,8 @@ def forward(self, x): ], "ir": ir, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -94,6 +96,8 @@ def forward(self, x): ], "ir": ir, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -150,6 +154,8 @@ def forward(self, x): ) ], "ir": ir, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -209,6 +215,8 @@ def forward(self, x): "ir": ir, "min_block_size": 1, "torch_executed_ops": {"torch.ops.aten.relu.default"}, + "cache_built_engines": 
False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -250,6 +258,8 @@ def test_resnet18(ir): ], "ir": ir, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -293,6 +303,8 @@ def test_resnet18_dynamic(ir): ], "ir": ir, "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -340,6 +352,8 @@ def forward(self, x): "ir": ir, "min_block_size": 1, "torch_executed_ops": {"torch.ops.aten.convolution.default"}, + "cache_built_engines": False, + "reuse_cached_engines": False, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) @@ -388,7 +402,14 @@ def forward(self, x): model = MyModule().eval().cuda() input = torch.randn((1, 3, 224, 224)).to("cuda") - trt_gm = torchtrt.compile(model, ir=ir, inputs=[input], min_block_size=1) + trt_gm = torchtrt.compile( + model, + ir=ir, + inputs=[input], + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) assertions.assertTrue( isinstance(trt_gm, torch.fx.GraphModule), msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule", diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 2d45af2b49..ba6cb0c776 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -30,6 +30,8 @@ def test_resnet18(ir): "pass_through_build_failures": True, "optimization_level": 1, "ir": "torch_compile", + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -61,6 +63,8 @@ def test_mobilenet_v2(ir): "optimization_level": 1, "min_block_size": 10, "ir": "torch_compile", + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -92,6 +96,8 @@ def test_efficientnet_b0(ir): "optimization_level": 1, "min_block_size": 10, "ir": "torch_compile", + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -132,6 +138,8 @@ def test_bert_base_uncased(ir): "optimization_level": 1, "min_block_size": 15, "ir": "torch_compile", + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -166,6 +174,8 @@ def test_resnet18_half(ir): "pass_through_build_failures": True, "optimization_level": 1, "ir": "torch_compile", + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index df71d6b58a..bf19c3c5e6 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -31,6 +31,8 @@ def test_resnet18(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -61,6 +63,8 @@ def test_mobilenet_v2(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -91,6 +95,8 @@ def test_efficientnet_b0(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + 
"cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -130,6 +136,8 @@ def test_bert_base_uncased(ir): "truncate_double": True, "ir": ir, "min_block_size": 10, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) model_outputs = model(input, input2) @@ -168,6 +176,8 @@ def test_resnet18_half(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -223,6 +233,8 @@ def calibrate_loop(model): enabled_precisions={torch.float8_e4m3fn}, min_block_size=1, debug=True, + cache_built_engines=False, + reuse_cached_engines=False, ) outputs_trt = trt_model(input_tensor) assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2) @@ -272,6 +284,8 @@ def calibrate_loop(model): enabled_precisions={torch.int8}, min_block_size=1, debug=True, + cache_built_engines=False, + reuse_cached_engines=False, ) outputs_trt = trt_model(input_tensor) assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2) diff --git a/tests/py/dynamo/runtime/test_001_streams.py b/tests/py/dynamo/runtime/test_001_streams.py index 574db6611e..aaec9e3d41 100644 --- a/tests/py/dynamo/runtime/test_001_streams.py +++ b/tests/py/dynamo/runtime/test_001_streams.py @@ -31,6 +31,8 @@ def forward(self, x): enabled_precisions={dtype}, min_block_size=1, device=device, + cache_built_engines=False, + reuse_cached_engines=False, ) for i in range(100): diff --git a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py index 1f3de69eb3..008b0f53b1 100644 --- a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py +++ b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py @@ -160,6 +160,8 @@ def test_lazy_engine_init_py_e2e(self): "ir": "dynamo", "lazy_engine_init": True, "use_python_runtime": True, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -194,6 +196,8 @@ def test_lazy_engine_init_cpp_e2e(self): "ir": "dynamo", "lazy_engine_init": True, "use_python_runtime": False, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -228,6 +232,8 @@ def test_lazy_engine_init_cpp_serialization(self): "ir": "dynamo", "lazy_engine_init": True, "use_python_runtime": False, + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -276,6 +282,8 @@ def forward(self, a, b): "lazy_engine_init": True, "use_python_runtime": True, "torch_executed_ops": [torch.ops.aten.sub.Tensor], + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec) @@ -318,6 +326,8 @@ def forward(self, a, b): "lazy_engine_init": True, "use_python_runtime": False, "torch_executed_ops": [torch.ops.aten.sub.Tensor], + "cache_built_engines": False, + "reuse_cached_engines": False, } trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)