diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 96e5f313ae..3b9c2df2e0 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
+    strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        refit_identical_engine_weights (bool): Build the engine assuming it will only ever be refitted with weights identical to the build-time weights (maps to TensorRT's REFIT_IDENTICAL builder flag on TensorRT >= 10.0), which lets the builder keep weight-dependent optimizations
+        strip_engine_weights (bool): Strip the weights from the serialized engine so that only the network structure is stored. This keeps serialized/cached engines small; the weights must be supplied again via refit before the engine can run
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -281,6 +285,8 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "refit_identical_engine_weights": refit_identical_engine_weights,
+        "strip_engine_weights": strip_engine_weights,
     }
 
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 68e446dab5..982e48c1d5 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+REFIT_IDENTICAL_ENGINE_WEIGHTS = False
+STRIP_ENGINE_WEIGHTS = False
 
 
 def default_device() -> Device:
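Taken together with the existing make_refittable option, the two new knobs are driven from torch_tensorrt.dynamo.compile. A minimal usage sketch, assuming only the keyword names added in this diff (MyModel and the input shape are placeholders):

```python
import torch
import torch_tensorrt as torch_trt

model = MyModel().eval().cuda()  # placeholder module
inputs = (torch.randn(1, 3, 224, 224).cuda(),)
exp_program = torch.export.export(model, args=inputs)

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    make_refittable=True,  # both new flags only take effect on refittable engines
    strip_engine_weights=True,  # serialize the plan without weights
    refit_identical_engine_weights=True,  # engine will only ever be refitted with the build-time weights
    use_python_runtime=True,  # the refit-at-load path in this PR lives in the Python runtime
    cache_built_engines=True,
    reuse_cached_engines=True,
)
out = trt_gm(*inputs)  # weights are refitted when the engine is set up
```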
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index f8886fbd67..e4458b187b 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -24,9 +24,11 @@
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
+    REFIT_IDENTICAL_ENGINE_WEIGHTS,
     REQUIRE_FULL_COMPILATION,
     REUSE_CACHED_ENGINES,
     SPARSE_WEIGHTS,
+    STRIP_ENGINE_WEIGHTS,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
@@ -78,6 +80,8 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        refit_identical_engine_weights (bool): Whether to build the engine assuming refit weights will be identical to the build-time weights
+        strip_engine_weights (bool): Whether to strip the weights from the serialized engine (they must be refitted before the engine can run)
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
+    refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS
+    strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -124,6 +130,8 @@ class CompilationSettings:
     "make_refittable",
     "engine_capability",
     "hardware_compatible",
+    "refit_identical_engine_weights",
+    "strip_engine_weights",
 )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index ff35bf39d7..63dd90a212 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
 )
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -283,7 +283,16 @@ def _populate_trt_builder_config(
             builder_config.clear_flag(trt.BuilderFlag.TF32)
 
         if self.compilation_settings.make_refittable:
-            builder_config.set_flag(trt.BuilderFlag.REFIT)
+            if version.parse(trt.__version__) >= version.parse("10.0"):
+                if self.compilation_settings.refit_identical_engine_weights:
+                    builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
+                else:
+                    builder_config.set_flag(trt.BuilderFlag.REFIT)
+            else:
+                builder_config.set_flag(trt.BuilderFlag.REFIT)
+
+        if self.compilation_settings.strip_engine_weights:
+            builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
 
         if strict_type_constraints:
             builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
@@ -542,7 +551,7 @@ def run(
             cached_data = self.engine_cache.check(hash_val)
             if cached_data is not None:  # hit the cache
                 (
-                    serialized_engine,
+                    unrefitted_serialized_engine,
                     self._input_names,
                     self._output_names,
                     cached_engine_input_specs,
@@ -573,31 +582,12 @@ def run(
                     "Found the cached engine that corresponds to this graph. It is directly loaded."
                 )
 
-                runtime = trt.Runtime(TRT_LOGGER)
-                engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-                from torch_tensorrt.dynamo._refit import (
-                    _refit_single_trt_engine_with_gm,
-                )
-
-                # TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
-                # We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future.
-                _refit_single_trt_engine_with_gm(
-                    new_gm=self.module,
-                    old_engine=engine,
-                    input_list=self.input_specs,
-                    settings=self.compilation_settings,
-                    weight_name_map=None,
-                )
-
-                serialized_engine = engine.serialize()
-
                 with io.BytesIO() as engine_bytes:
-                    engine_bytes.write(serialized_engine)
-                    engine_str = engine_bytes.getvalue()
+                    engine_bytes.write(unrefitted_serialized_engine)
+                    unrefitted_engine_str = engine_bytes.getvalue()
 
                 return TRTInterpreterResult(
-                    engine_str,
+                    unrefitted_engine_str,
                     self._input_names,
                     self._output_names,
                     self.weight_name_map,
@@ -619,19 +609,24 @@ def run(
             builder_config, self.compilation_settings.timing_cache_path
         )
 
-        serialized_engine = self.builder.build_serialized_network(
+        # if strip_engine_weights is true, the serialized engine needs to be refitted before use
+        maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
             self.ctx.net, builder_config
         )
-        assert serialized_engine
+        assert maybe_unrefitted_serialized_engine
 
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
+        _LOGGER.info(
+            f"TRT Engine uses: {maybe_unrefitted_serialized_engine.nbytes} bytes of Memory"
+        )
 
         self._save_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
+
+        # if strip_engine_weights is true, the weight-stripped engine will be saved in the engine cache
         if (
             self.engine_cache is not None
             and self.compilation_settings.cache_built_engines
@@ -639,7 +634,7 @@ def run(
             self.engine_cache.insert(
                 hash_val,
                 (
-                    serialized_engine,
+                    maybe_unrefitted_serialized_engine,
                     self._input_names,
                     self._output_names,
                     self.input_specs,
@@ -649,11 +644,14 @@ def run(
             )
 
         with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+            engine_bytes.write(maybe_unrefitted_serialized_engine)
+            maybe_unrefitted_engine_str = engine_bytes.getvalue()
 
         return TRTInterpreterResult(
-            engine_str, self._input_names, self._output_names, self.weight_name_map
+            maybe_unrefitted_engine_str,
+            self._input_names,
+            self._output_names,
+            self.weight_name_map,
         )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index f0b65b3a6e..aa7ff05cc8 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Any, List, Optional, Sequence
 
+import tensorrt as trt
 import torch
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -166,4 +165,5 @@ def convert_module(
         name=name,
         settings=settings,
         weight_name_map=interpreter_result.weight_name_map,
+        graph_module=module,
     )
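The REFIT / REFIT_IDENTICAL / STRIP_PLAN branch added to _populate_trt_builder_config above can be reproduced in isolation against a raw TensorRT builder config. A minimal sketch under the same assumptions as the hunk (TensorRT >= 10.0 for REFIT_IDENTICAL; the helper name apply_refit_flags is illustrative):

```python
import tensorrt as trt
from packaging import version


def apply_refit_flags(
    builder_config: trt.IBuilderConfig,
    make_refittable: bool,
    refit_identical_engine_weights: bool,
    strip_engine_weights: bool,
) -> None:
    """Mirror of the flag selection in _populate_trt_builder_config."""
    if make_refittable:
        if version.parse(trt.__version__) >= version.parse("10.0"):
            if refit_identical_engine_weights:
                # Engine may only be refitted with weights identical to the build-time
                # weights, which lets the builder keep weight-dependent optimizations.
                builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
            else:
                builder_config.set_flag(trt.BuilderFlag.REFIT)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)

    if strip_engine_weights:
        # The serialized plan will not contain weights; they must be refitted before execution.
        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
```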
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index f74c239550..7852c85b9f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -5,13 +5,14 @@
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch.nn import Module
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
@@ -19,8 +20,6 @@
     multi_gpu_device_check,
 )
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -39,7 +38,8 @@ def __init__(
         *,
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),
-        weight_name_map: Any = None,
+        weight_name_map: Optional[dict[Any, Any]] = None,
+        graph_module: Optional[torch.fx.GraphModule] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -52,6 +52,8 @@ def __init__(
         Keyword Arguments:
             name (str): Name for module
            settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
+            weight_name_map (dict): Mapping of engine weight name to state_dict weight name
+            graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
 
         Example:
 
@@ -106,6 +108,7 @@ def __init__(
         self.settings = settings
         self.engine = None
         self.weight_name_map = weight_name_map
+        self.graph_module = graph_module  # may be used to refit the weights
         self.target_platform = Platform.current_platform()
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
@@ -121,6 +124,51 @@ def setup_engine(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
         self.context = self.engine.create_execution_context()
 
+        if self.settings.strip_engine_weights:
+            assert (
+                self.settings.make_refittable
+            ), "weight-stripped engines must be refittable, please set make_refittable=True"
+
+            # Refit the weights
+            refitter = trt.Refitter(self.engine, TRT_LOGGER)
+            refittable_weights = refitter.get_all_weights()
+            torch_device = get_model_device(self.graph_module)
+
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )
+            from torch_tensorrt.dynamo._refit import (
+                construct_refit_mapping_from_weight_name_map,
+            )
+
+            mapping = construct_refit_mapping_from_weight_name_map(
+                self.weight_name_map, self.graph_module.state_dict()
+            )
+
+            for layer_name in refittable_weights:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Build the TRT weights in place from the torch tensor's data pointer
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(
+                    layer_name, trt_wt_tensor, trt_wt_location
+                )
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"
+
+            # Refit the engine
+            if refitter.refit_cuda_engine():
+                logger.info("Engine refitted successfully!")
+            else:
+                logger.info("Engine refit failed!")
+
         assert self.engine.num_io_tensors == (
             len(self.input_names) + len(self.output_names)
         )
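Stripped of the module bookkeeping, the refit-at-load step in setup_engine reduces to the sketch below, assuming a mapping from engine weight name to (torch tensor, TRT dtype) such as the one produced by construct_refit_mapping_from_weight_name_map (the wrapper name refit_stripped_engine and the on_gpu switch are illustrative):

```python
import tensorrt as trt
import torch


def refit_stripped_engine(
    engine: trt.ICudaEngine,
    mapping: dict,  # engine weight name -> (torch.Tensor, trt.DataType)
    trt_logger: trt.ILogger,
    on_gpu: bool = True,
) -> bool:
    """Push weights from `mapping` into a refittable, weight-stripped engine."""
    refitter = trt.Refitter(engine, trt_logger)
    location = trt.TensorLocation.DEVICE if on_gpu else trt.TensorLocation.HOST

    for layer_name in refitter.get_all_weights():
        if layer_name not in mapping:
            continue  # unmapped names will show up in get_missing_weights()
        weight, weight_dtype = mapping[layer_name]
        trt_weights = trt.Weights(weight_dtype, weight.data_ptr(), torch.numel(weight))
        refitter.set_named_weights(layer_name, trt_weights, location)

    assert len(refitter.get_missing_weights()) == 0, "incomplete weight mapping"
    return refitter.refit_cuda_engine()
```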
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 7bf42da7f0..ccfbad352e 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -79,6 +79,7 @@ def __init__(
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),  # Assumes engine was built with default compilation settings if object not passed
         weight_name_map: Optional[dict[Any, Any]] = None,
+        graph_module: Optional[torch.fx.GraphModule] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines
@@ -96,6 +97,8 @@ def __init__(
         Keyword Arguments:
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
+            weight_name_map (dict): Mapping of engine weight name to state_dict weight name
+            graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
 
         Example:
 
@@ -129,6 +132,7 @@ def __init__(
         self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
         self.weight_name_map = weight_name_map
+        self.graph_module = graph_module
         self.serialized_engine = serialized_engine
         self.engine = None
 
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 367f68c1f6..cd720bc030 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -206,6 +206,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
+            # remove timing cache and reset dynamo for engine caching measurement
             remove_timing_cache()
             torch._dynamo.reset()
             if i == 0:
@@ -220,7 +221,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             trt_gm = torch_trt.dynamo.compile(
                 exp_program,
                 tuple(inputs),
-                use_python_runtime=False,
+                use_python_runtime=True,
                 enabled_precisions={torch.float},
                 debug=False,
                 min_block_size=1,
@@ -231,7 +232,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             )
             end.record()
             torch.cuda.synchronize()
-            torch._dynamo.reset()
             times.append(start.elapsed_time(end))
             results.append(trt_gm(*inputs))
 
@@ -285,7 +285,7 @@ def test_dynamo_compile_with_custom_engine_cache(self):
             trt_gm = torch_trt.dynamo.compile(
                 exp_program,
                 tuple(inputs),
-                use_python_runtime=False,
+                use_python_runtime=True,
                 enabled_precisions={torch.float},
                 debug=False,
                 min_block_size=1,
@@ -332,7 +332,7 @@ def test_dynamo_compile_change_input_shape(self):
             trt_gm = torch_trt.dynamo.compile(
                 torch.export.export(model, args=inputs),
                 inputs=inputs,
-                use_python_runtime=False,
+                use_python_runtime=True,
                 enabled_precisions={torch.float},
                 debug=False,
                 min_block_size=1,
@@ -402,7 +402,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             results.append(compiled_model(*inputs))  # trigger the compilation
             end.record()
             torch.cuda.synchronize()
-            torch._dynamo.reset()
             times.append(start.elapsed_time(end))
 
         cos_sim = cosine_similarity(results[0], results[1])
@@ -441,7 +440,6 @@ def test_torch_compile_with_custom_engine_cache(self):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
             if i == 0:
                 cache_built_engines = False
                 reuse_cached_engines = False
@@ -501,7 +499,6 @@ def test_torch_compile_change_input_shape(self):
 
         custom_engine_cache = MyEngineCache(engine_cache_dir)
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
             inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")]
             compiled_model = torch.compile(
                 model,
diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py
new file mode 100644
index 0000000000..4cef33e082
--- /dev/null
+++ b/tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -0,0 +1,143 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+from torch.testing._internal.common_utils import TestCase
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+assertions = unittest.TestCase()
+
+
+class TestWeightStrippedEngine(TestCase):
+    def test_weight_stripped_engine(self):
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        # Mark the dim0 of inputs as dynamic
+        batch = torch.export.Dim("batch", min=1, max=200)
+        exp_program = torch.export.export(
+            model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+
+        engine_cache_dir = "/tmp/test_weight_stripped_engine"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+        results = []
+
+        # run pytorch model
+        results.append(model(*inputs))
+
+        remove_timing_cache()
+        torch._dynamo.reset()
+
+        trt_gm = torch_trt.dynamo.compile(
+            exp_program,
+            tuple(inputs),
+            use_python_runtime=True,
+            enabled_precisions={torch.float},
+            debug=False,
+            min_block_size=1,
+            make_refittable=True,
+            refit_identical_engine_weights=False,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+            engine_cache_dir=engine_cache_dir,
+            strip_engine_weights=True,
+        )
+        results.append(trt_gm(*inputs))
+
+        cos_sim = cosine_similarity(results[0], results[1])
+
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        # Mark the dim0 of inputs as dynamic
+        batch = torch.export.Dim("batch", min=1, max=200)
+        exp_program = torch.export.export(
+            model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+
+        engine_cache_dir = (
+            "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine"
+        )
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
+        # The 1st iteration is to measure the compilation time without engine caching
+        # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+        # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+        # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
+        results = []
+        times = []
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        for i in range(3):
+            remove_timing_cache()
+            torch._dynamo.reset()
+            if i == 0:
+                cache_built_engines = False
+                reuse_cached_engines = False
+            else:
+                cache_built_engines = True
+                reuse_cached_engines = True
+
+            torch.cuda.synchronize()
+            start.record()
+            trt_gm = torch_trt.dynamo.compile(
+                exp_program,
+                tuple(inputs),
+                use_python_runtime=True,
+                enabled_precisions={torch.float},
+                debug=False,
+                min_block_size=1,
+                make_refittable=True,
+                refit_identical_engine_weights=True,
+                strip_engine_weights=True,
+                cache_built_engines=cache_built_engines,
+                reuse_cached_engines=reuse_cached_engines,
+                engine_cache_dir=engine_cache_dir,
+            )
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))
+            results.append(trt_gm(*inputs))
+
+        assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros")
+        assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros")
+        assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros")
+
+        cos_sim = cosine_similarity(results[0], results[1])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        cos_sim = cosine_similarity(results[1], results[2])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        assertions.assertTrue(
+            times[0] > times[2],
+            msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
+        )
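Not part of the PR, but a hedged sketch of how one might eyeball the size benefit of weight stripping, assuming the Python runtime module keeps the serialized plan bytes on its serialized_engine attribute (attribute name taken from PythonTorchTensorRTModule above; the resnet18 setup mirrors the tests):

```python
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models


def total_plan_bytes(strip: bool) -> int:
    """Compile resnet18 and sum the serialized TRT plan sizes of all TRT submodules."""
    model = models.resnet18().eval().to("cuda")
    inputs = (torch.randn((1, 3, 224, 224)).to("cuda"),)
    exp_program = torch.export.export(model, args=inputs)
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        inputs,
        use_python_runtime=True,
        min_block_size=1,
        make_refittable=True,
        strip_engine_weights=strip,
        cache_built_engines=False,
        reuse_cached_engines=False,
    )
    return sum(
        len(m.serialized_engine)
        for m in trt_gm.modules()
        if getattr(m, "serialized_engine", None) is not None
    )


# Expectation: the weight-stripped plan should be substantially smaller, since only
# the network structure (not the weights) is serialized.
print("full:", total_plan_bytes(strip=False), "stripped:", total_plan_bytes(strip=True))
```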