diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 08cf78ec5f..90c379e4b0 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -86,8 +86,8 @@ def compile(
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
     cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
     reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
-    engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
-    engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
+    engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
+    engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
@@ -156,8 +156,8 @@ def compile(
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
-        engine_cache_dir (str): Directory to store the cached TRT engines
-        engine_cache_size (int): Maximum hard-disk space to use for the engine cache
+        engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
+        engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         **kwargs: Any,
     Returns:
@@ -235,12 +235,16 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))

+    engine_cache = None
     if cache_built_engines or reuse_cached_engines:
         assert (
             make_refitable
         ), "Engine caching requires make_refitable to be set to True"
-        if custom_engine_cache is None:
-            custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
+        engine_cache = (
+            custom_engine_cache
+            if custom_engine_cache is not None
+            else DiskEngineCache(engine_cache_dir, engine_cache_size)
+        )

     compilation_options = {
         "enabled_precisions": (
@@ -277,12 +281,13 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
-        "custom_engine_cache": custom_engine_cache,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
+    trt_gm = compile_module(
+        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
+    )
     return trt_gm


@@ -291,6 +296,7 @@ def compile_module(
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module

@@ -301,6 +307,7 @@ def compile_module(
         arg_inputs: Inputs to the module
         kwarg_inputs: kwargs to the module
         settings: Compilation settings
+        engine_cache: Engine cache instance to store/load compiled engines
     Returns:
         Compiled FX GraphModule
     """
@@ -480,6 +487,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 submodule_inputs,
                 settings=settings,
                 name=name,
+                engine_cache=engine_cache,
             )

         trt_modules[name] = trt_module
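The `custom_engine_cache` parameter above accepts any `BaseEngineCache` subclass. As a rough illustration of the contract, a user-supplied in-memory cache might look like the sketch below; the `save`/`load` signatures are modeled on the `MyEngineCache` helper in the tests at the end of this diff and are assumptions, not a documented API, and the class name is hypothetical.

```python
from typing import Dict, Optional

from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


class RAMEngineCache(BaseEngineCache):  # hypothetical example class
    def __init__(self) -> None:
        self.blobs: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes, prefix: str = "blob") -> None:
        # Keep packed engines in memory rather than on disk.
        self.blobs[f"{prefix}_{hash}"] = blob

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # Returning None signals a cache miss, so compilation proceeds normally.
        return self.blobs.get(f"{prefix}_{hash}")


# Passed via the public API; engine_cache_dir/engine_cache_size are then ignored:
# trt_gm = torch_trt.dynamo.compile(
#     exp_program,
#     inputs,
#     make_refitable=True,  # engine caching asserts this is True
#     cache_built_engines=True,
#     reuse_cached_engines=True,
#     custom_engine_cache=RAMEngineCache(),
# )
```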
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 0327727c9f..063f6f3718 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    CUSTOM_ENGINE_CACHE,
     DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -36,7 +35,6 @@
     WORKSPACE_SIZE,
     default_device,
 )
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


 @dataclass
@@ -80,7 +78,6 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
-        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -115,4 +112,3 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
-    custom_engine_cache: Optional[BaseEngineCache] = CUSTOM_ENGINE_CACHE
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index ae3cb38f2d..605d963a50 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -48,14 +48,15 @@ def torch_tensorrt_backend(
 def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
-    settings = parse_dynamo_kwargs(kwargs)
-    return _pretraced_backend(gm, sample_inputs, settings)
+    settings, engine_cache = parse_dynamo_kwargs(kwargs)
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)


 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    engine_cache: Any = None,
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines

@@ -63,6 +64,7 @@ def _pretraced_backend(
         module: FX GraphModule to convert
         inputs: Inputs to the module
         settings: Compilation settings
+        engine_cache: Engine cache instance
     Returns:
         Compiled FX GraphModule
     """
@@ -109,6 +111,7 @@ def _pretraced_backend(
                 gm,
                 torchtrt_inputs,
                 settings=settings,
+                engine_cache=engine_cache,
             )
             return trt_compiled
         except (AssertionError, RuntimeError):
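Since `parse_dynamo_kwargs` now returns a pair rather than a bare `CompilationSettings`, any wrapper built on these internals must unpack both values. A minimal sketch of the new calling convention (the wrapper name is hypothetical, and it leans on private modules exactly as the diff does):

```python
from typing import Any, Sequence

import torch
from torch_tensorrt.dynamo.backend.backends import _pretraced_backend
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs


def my_aten_backend(  # hypothetical wrapper, mirroring aot_torch_tensorrt_aten_backend
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
    # Now a (settings, engine_cache) tuple; engine_cache is None unless caching is enabled.
    settings, engine_cache = parse_dynamo_kwargs(kwargs)
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
```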
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 16a1e0c75b..22743af0aa 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,6 +27,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -71,6 +72,7 @@ def __init__(
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
+        engine_cache: Optional[BaseEngineCache] = None,
     ):
         super().__init__(module)

@@ -126,6 +128,9 @@ def __init__(
         self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
         self.weight_name_map: Optional[dict[str, Any]] = None

+        # Engine cache for storing and reusing TRT engines
+        self.engine_cache = engine_cache
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

@@ -521,22 +526,22 @@ def run(
         Return:
             TRTInterpreterResult
         """
-        if (
-            self.compilation_settings.custom_engine_cache is not None
-        ):  # custom_engine_cache could be None if this function is called from convert_exported_program_to_serialized_trt_engine etc.
+        # self.engine_cache could be None if:
+        # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
+        # 2) both cache_built_engines and reuse_cached_engines are False
+        if self.engine_cache is not None:
             if (
                 self.compilation_settings.cache_built_engines
                 or self.compilation_settings.reuse_cached_engines
             ):
-                engine_cache = self.compilation_settings.custom_engine_cache
-                hash_val = engine_cache.get_hash(self.module)
+                hash_val = self.engine_cache.get_hash(self.module)

                 if self.compilation_settings.reuse_cached_engines:
                     # query the cached TRT engine
-                    blob = engine_cache.load(hash_val)
+                    blob = self.engine_cache.load(hash_val)
                     if blob is not None:  # hit the cache
                         serialized_engine, input_names, output_names, weight_name_map = (
-                            engine_cache.unpack(blob)
+                            self.engine_cache.unpack(blob)
                         )
                         self._input_names = input_names
                         self._output_names = output_names
@@ -605,16 +610,16 @@ def run(
             builder_config, self.compilation_settings.timing_cache_path
         )
         if (
-            self.compilation_settings.custom_engine_cache is not None
+            self.engine_cache is not None
             and self.compilation_settings.cache_built_engines
         ):
-            blob = engine_cache.pack(
+            blob = self.engine_cache.pack(
                 serialized_engine,
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
             )
-            engine_cache.save(hash_val, blob)
+            self.engine_cache.save(hash_val, blob)

         with io.BytesIO() as engine_bytes:
             engine_bytes.write(serialized_engine)
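The cache protocol `run()` follows can be distilled as below. This is only a sketch of the hit/miss flow, assuming `pack` and `unpack` are symmetric (which the usage above implies); the helper name is hypothetical, not part of the library.

```python
from typing import Any, Optional, Tuple

import torch
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


def try_load_cached_engine(
    engine_cache: Optional[BaseEngineCache], module: torch.fx.GraphModule
) -> Optional[Tuple[Any, ...]]:
    """Return (serialized_engine, input_names, output_names, weight_name_map)
    on a hit, or None when caching is disabled or the engine was never built."""
    if engine_cache is None:  # e.g. convert_exported_program_to_serialized_trt_engine
        return None
    hash_val = engine_cache.get_hash(module)  # key derived from the FX module
    blob = engine_cache.load(hash_val)        # None on a miss
    if blob is None:
        return None
    return engine_cache.unpack(blob)


# On a miss, run() builds the engine and then persists it:
#   blob = engine_cache.pack(serialized_engine, input_names, output_names, weight_name_map)
#   engine_cache.save(hash_val, blob)
```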
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 2041dd1c37..97359c65a7 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -10,6 +10,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
     TRTInterpreter,
@@ -70,6 +71,7 @@ def interpret_module_to_result(
     settings: CompilationSettings = CompilationSettings(),
     arg_inputs: Optional[Sequence[Input]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> TRTInterpreterResult:
     """Interpret an FX module to a TRTInterpreterResult
     Args:
@@ -79,6 +81,7 @@ def interpret_module_to_result(
         arg_inputs: Sequence of Tensors representing inputs to the module.
         kwarg_inputs: A dictionary of Tensors representing inputs to the module.
         settings: Compilation settings
+        engine_cache: Engine cache instance
     Returns:
         TRTInterpreterResult
     """
@@ -105,6 +108,7 @@ def interpret_module_to_result(
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
+        engine_cache=engine_cache,
     )
     interpreter_result = interpreter.run()
     return interpreter_result
@@ -115,6 +119,7 @@ def convert_module(
     inputs: Sequence[Input],
     settings: CompilationSettings = CompilationSettings(),
     name: str = "",
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> PythonTorchTensorRTModule | TorchTensorRTModule:
     """Convert an FX module to a TRT module
     Args:
@@ -122,10 +127,13 @@ def convert_module(
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
         name: TRT engine name
+        engine_cache: Engine cache instance
     Returns:
         PythonTorchTensorRTModule or TorchTensorRTModule
     """

-    interpreter_result = interpret_module_to_result(module, inputs, settings)
+    interpreter_result = interpret_module_to_result(
+        module, inputs, settings, engine_cache=engine_cache
+    )

     rt_cls = PythonTorchTensorRTModule
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 968bdf9858..41ff1680ad 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -3,7 +3,7 @@
 import logging
 from dataclasses import fields, replace
 from enum import Enum
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import numpy as np
 import tensorrt as trt
@@ -12,6 +12,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings

 from packaging import version
@@ -301,7 +302,9 @@ def to_torch_tensorrt_device(
     return Device._from(device)


-def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
+def parse_dynamo_kwargs(
+    kwargs: Any,
+) -> Tuple[CompilationSettings, Optional[BaseEngineCache]]:
     """Parses the kwargs field of a Dynamo backend

     Args:
@@ -360,11 +363,15 @@ def parse_dynamo_kwargs(

     # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
     # then create a default disk engine cache
+    engine_cache = None
     if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
         assert kwargs.get(
             "make_refitable"
         ), "Engine caching requires make_refitable to be set to True"
-        if settings.custom_engine_cache is None:
+
+        if kwargs.get("custom_engine_cache") is not None:
+            engine_cache = kwargs.get("custom_engine_cache")
+        else:
             from torch_tensorrt.dynamo._engine_caching import DiskEngineCache

             engine_cache_dir = kwargs.get(
@@ -373,13 +380,11 @@ def parse_dynamo_kwargs(
             engine_cache_size = kwargs.get(
                 "engine_cache_size", _defaults.ENGINE_CACHE_SIZE
             )
-            settings.custom_engine_cache = DiskEngineCache(
-                engine_cache_dir, engine_cache_size
-            )
+            engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)

     logger.info("Compilation Settings: %s\n", settings)

-    return settings
+    return settings, engine_cache


 def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]:
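For the `torch.compile` path, the resolution above means a user-provided cache always wins, and a `DiskEngineCache` is constructed from the remaining options otherwise. Roughly, under the assumption that the defaults come from `_defaults` exactly as the diff shows:

```python
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_caching import DiskEngineCache

options = {
    "make_refitable": True,        # asserted whenever either caching flag is set
    "cache_built_engines": True,
    "reuse_cached_engines": True,
    # no "custom_engine_cache" key -> fall back to a disk cache
}

engine_cache = None
if options.get("cache_built_engines") or options.get("reuse_cached_engines"):
    engine_cache = options.get("custom_engine_cache") or DiskEngineCache(
        options.get("engine_cache_dir", _defaults.ENGINE_CACHE_DIR),
        options.get("engine_cache_size", _defaults.ENGINE_CACHE_SIZE),
    )
```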
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 7b6247ced9..1a5b874eb4 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -49,7 +49,7 @@ def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:


 class TestEngineCache(TestCase):
-    def test_dynamo_compile(self):
+    def test_dynamo_compile_with_default_disk_engine_cache(self):
         model = models.resnet18(pretrained=True).eval().to("cuda")
         example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
         # Mark the dim0 of inputs as dynamic
@@ -57,15 +57,87 @@ def test_dynamo_compile(self):
         exp_program = torch.export.export(
             model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
         )
+        engine_cache_dir = ENGINE_CACHE_DIR
         if os.path.exists(engine_cache_dir):
             shutil.rmtree(engine_cache_dir)
+
+        # The 1st iteration is to measure the compilation time without engine caching
+        # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+        # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+        # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+        results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        for i in range(3):
+            if i == 0:
+                cache_built_engines = False
+                reuse_cached_engines = False
+            else:
+                cache_built_engines = True
+                reuse_cached_engines = True
+
+            start.record()
+            trt_gm = torch_trt.dynamo.compile(
+                exp_program,
+                tuple(inputs),
+                use_python_runtime=False,
+                enabled_precisions={torch.float},
+                debug=False,
+                min_block_size=1,
+                make_refitable=True,
+                cache_built_engines=cache_built_engines,
+                reuse_cached_engines=reuse_cached_engines,
+            )
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
+            results.append(trt_gm(*inputs))
+
+        cos_sim = cosine_similarity(results[0], results[1])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_dynamo_compile_with_default_disk_engine_cache: results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        cos_sim = cosine_similarity(results[1], results[2])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_dynamo_compile_with_default_disk_engine_cache: results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"test_dynamo_compile_with_default_disk_engine_cache: Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )
+
+    def test_dynamo_compile_with_custom_engine_cache(self):
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+
+        engine_cache_dir = "/tmp/your_dir"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        custom_engine_cache = MyEngineCache(engine_cache_dir)
+
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        # Mark the dim0 of inputs as dynamic
+        batch = torch.export.Dim("batch", min=1, max=200)
+        exp_program = torch.export.export(
+            model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+
         # The 1st iteration is to measure the compilation time without engine caching
         # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
         # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
         # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
         inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
         results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
             if i == 0:
                 cache_built_engines = False
@@ -74,6 +146,7 @@ def test_dynamo_compile(self):
                 cache_built_engines = True
                 reuse_cached_engines = True

+            start.record()
             trt_gm = torch_trt.dynamo.compile(
                 exp_program,
                 tuple(inputs),
@@ -84,23 +157,95 @@ def test_dynamo_compile(self):
                 make_refitable=True,
                 cache_built_engines=cache_built_engines,
                 reuse_cached_engines=reuse_cached_engines,
-                engine_cache_size=1 << 30,  # 1GB
+                custom_engine_cache=custom_engine_cache,
             )
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
             results.append(trt_gm(*inputs))

         cos_sim = cosine_similarity(results[0], results[1])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
-            msg=f"test_dynamo_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            msg=f"test_dynamo_compile_with_custom_engine_cache: results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

         cos_sim = cosine_similarity(results[1], results[2])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
-            msg=f"test_dynamo_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            msg=f"test_dynamo_compile_with_custom_engine_cache: results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

-    def test_torch_compile(self):
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"test_dynamo_compile_with_custom_engine_cache: Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )
+
+    def test_torch_compile_with_default_disk_engine_cache(self):
+        # Custom Engine Cache
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+
+        engine_cache_dir = "/tmp/test_torch_compile_with_default_disk_engine_cache"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        # The 1st iteration is to measure the compilation time without engine caching
+        # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+        # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+        # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+        results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        for i in range(3):
+            # remove timing cache and reset dynamo for engine caching measurement
+            if i == 0:
+                cache_built_engines = False
+                reuse_cached_engines = False
+            else:
+                cache_built_engines = True
+                reuse_cached_engines = True
+
+            start.record()
+            compiled_model = torch.compile(
+                model,
+                backend="tensorrt",
+                options={
+                    "use_python_runtime": True,
+                    "enabled_precisions": {torch.float},
+                    "debug": False,
+                    "min_block_size": 1,
+                    "make_refitable": True,
+                    "cache_built_engines": cache_built_engines,
+                    "reuse_cached_engines": reuse_cached_engines,
+                    "engine_cache_dir": engine_cache_dir,
+                    "engine_cache_size": 1 << 30,  # 1GB
+                },
+            )
+            results.append(compiled_model(*inputs))  # trigger the compilation
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
+
+        cos_sim = cosine_similarity(results[0], results[1])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_torch_compile_with_default_disk_engine_cache: results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        cos_sim = cosine_similarity(results[1], results[2])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_torch_compile_with_default_disk_engine_cache: results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"test_torch_compile_with_default_disk_engine_cache: Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )
+
+    def test_torch_compile_with_custom_engine_cache(self):
         # Custom Engine Cache
         model = models.resnet18(pretrained=True).eval().to("cuda")

@@ -108,13 +253,16 @@ def test_torch_compile(self):
         if os.path.exists(engine_cache_dir):
             shutil.rmtree(engine_cache_dir)

-        engine_cache = MyEngineCache(engine_cache_dir)
+        custom_engine_cache = MyEngineCache(engine_cache_dir)

         # The 1st iteration is to measure the compilation time without engine caching
         # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
         # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
         # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
         inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
         results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
             # remove timing cache and reset dynamo for engine caching messurement
             if i == 0:
@@ -124,6 +272,7 @@ def test_torch_compile(self):
                 cache_built_engines = True
                 reuse_cached_engines = True

+            start.record()
             compiled_model = torch.compile(
                 model,
                 backend="tensorrt",
@@ -135,19 +284,27 @@ def test_torch_compile(self):
                     "make_refitable": True,
                     "cache_built_engines": cache_built_engines,
                     "reuse_cached_engines": reuse_cached_engines,
-                    "custom_engine_cache": engine_cache,  # use custom engine cache
+                    "custom_engine_cache": custom_engine_cache,
                 },
             )
             results.append(compiled_model(*inputs))  # trigger the compilation
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))

         cos_sim = cosine_similarity(results[0], results[1])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
-            msg=f"test_torch_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            msg=f"test_torch_compile_with_custom_engine_cache: results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

         cos_sim = cosine_similarity(results[1], results[2])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
-            msg=f"test_torch_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            msg=f"test_torch_compile_with_custom_engine_cache: results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"test_torch_compile_with_custom_engine_cache: Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
         )