diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index a97cb528d4..660cb8a875 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -3,7 +3,7 @@
 import collections.abc
 import copy
 import logging
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -13,7 +13,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._exporter import inline_torch_modules
-from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
     return weight_map
 
 
+def construct_refit_mapping_from_weight_name_map(
+    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
+) -> dict[Any, Any]:
+    engine_weight_map = {}
+    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
+        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch Norm Layer
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w]
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+            shift = params["bias"] - params["running_mean"] * scale
+            # Assign the computed scale or shift, as indicated by the engine weight name
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+
+        elif sd_weight_name not in state_dict:
+            # If the weight is not in the state_dict, leave it unchanged
+            continue
+        else:
+            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
+
+        engine_weight_map[engine_weight_name] = (
+            engine_weight_map[engine_weight_name]
+            .clone()
+            .reshape(-1)
+            .contiguous()
+            .to(torch_dtype),
+            trt_dtype,
+        )
+
+    return engine_weight_map
+
+
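For reference, the batch-norm folding performed above can be checked in isolation. A minimal sketch, assuming a standard eval-mode BatchNorm2d (the 1e-7 epsilon matches the constant used in construct_refit_mapping_from_weight_name_map; PyTorch's own BatchNorm2d defaults to eps=1e-5, hence the loose tolerance):

    import torch

    bn = torch.nn.BatchNorm2d(3).eval()
    params = {name.split(".")[-1]: p for name, p in bn.state_dict().items()}

    # Same folding as above: a TRT IScaleLayer computes y = x * scale + shift
    scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
    shift = params["bias"] - params["running_mean"] * scale

    x = torch.randn(1, 3, 4, 4)
    folded = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    assert torch.allclose(bn(x), folded, atol=1e-4)
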
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
-    input_list: Tuple[Any, ...],
+    input_list: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    weight_name_map: Optional[dict[str, List[str]]] = None,
 ) -> None:
     """
     Refit a TensorRT Engine in place
     """
-    # Get the refitting mapping
-    mapping = construct_refit_mapping(new_gm, input_list, settings)
+    refitted = set()
 
-    trt_wt_location = trt.TensorLocation.HOST
     refitter = trt.Refitter(old_engine, TRT_LOGGER)
     weight_list = refitter.get_all_weights()
 
-    for layer_name in weight_list:
-        if layer_name not in mapping:
-            raise AssertionError(f"{layer_name} is not found in weight mapping")
-        # Use Numpy to create weights
-        weight, datatype = mapping[layer_name]
-        trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-        refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        refitted.add(layer_name)
+    if weight_name_map:
+        # Get the refitting mapping
+        trt_wt_location = trt.TensorLocation.DEVICE
+        mapping = construct_refit_mapping_from_weight_name_map(
+            weight_name_map, new_gm.state_dict()
+        )
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                logger.warning(f"{layer_name} is not found in weight mapping.")
+                continue
+            # Bind the torch tensor's device memory as TRT weights
+            weight, weight_dtype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(
+                weight_dtype, weight.data_ptr(), torch.numel(weight)
+            )
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+
+        assert (
+            len(refitter.get_missing_weights()) == 0
+        ), "Fast refitting failed due to incomplete mapping"
 
-    if len(refitted) != len(weight_list):
-        logger.warning("Not all weights have been refitted!!!")
+    else:
+        mapping = construct_refit_mapping(new_gm, input_list, settings)
+        trt_wt_location = trt.TensorLocation.HOST
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                raise AssertionError(f"{layer_name} is not found in weight mapping")
+            # Use Numpy to create weights
+            weight, datatype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            refitted.add(layer_name)
+
+        if len(refitted) != len(weight_list):
+            logger.warning("Not all weights have been refitted!")
 
     if not refitter.refit_cuda_engine():
         logger.error("Error: failed to refit new weights.")
-        exit(0)
+        raise AssertionError("Refitting failed.")
 
 
 def refit_module_weights(
@@ -148,6 +207,8 @@ def refit_module_weights(
     arg_inputs: Optional[Tuple[Any, ...]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     verify_output: bool = False,
+    use_weight_map_cache: bool = True,
+    in_place: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@
     if len(list(compiled_module.named_children())) == 0:
         inline_module = True
 
-    compiled_module = copy.deepcopy(compiled_module)
+    if not in_place:
+        compiled_module = copy.deepcopy(compiled_module)
+    elif inline_module:
+        raise AssertionError(
+            "An inline module (serialized engine) cannot be refitted in place. Please set in_place to False and use the returned graph module."
+        )
 
     # Get the settings and check the setting to be uniform
     settings: CompilationSettings = None
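The two new keyword arguments are exercised end to end in the tests at the bottom of this patch; for orientation, a minimal usage sketch (model choice and input shapes are illustrative):

    import torch
    import torchvision.models as models
    import torch_tensorrt
    from torch_tensorrt.dynamo import refit_module_weights

    model = models.resnet18().eval().cuda()
    model2 = models.resnet18(pretrained=True).eval().cuda()
    inputs = (torch.randn(1, 3, 224, 224).cuda(),)

    exp_program = torch.export.export(model, inputs)
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program, inputs, make_refitable=True  # required for any later refit
    )

    # Swap in model2's weights without rebuilding the engine. The cached weight
    # map is tried first; on failure the mapping is reconstructed from scratch.
    new_trt_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=torch.export.export(model2, inputs),
        arg_inputs=list(inputs),
        use_weight_map_cache=True,
        in_place=False,  # in_place=True is rejected for inline (serialized) modules
    )
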
@@ -182,13 +248,14 @@
         for name, engine in compiled_module.__dict__.items()
         if "engine" in name
     ]
-    encoded_settings = compiled_submodules[0][1].__getstate__()[0][
+    # [('_run_on_acc_0', inline_module)]
+    encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
         SERIALIZED_METADATA_IDX
     ]
     assert (
-        encoded_settings != ""
-    ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
-    settings = TorchTensorRTModule.decode_metadata(encoded_settings)
+        encoded_metadata != ""
+    ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old. Please recompile with the latest version and make_refitable=True."
+    settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
     # Handle torch modules
     compiled_submodules_map = dict(compiled_submodules)
     for name, submodule in compiled_module.named_children():
@@ -287,6 +354,7 @@
         # Extract engine from the submodule
         try:
             if inline_module:
+                weight_name_map = None
                 compiled_submodule = compiled_submodules_map[name]
                 # If this is a torch module, load the old state_dict
                 if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
+                    if use_weight_map_cache:
+                        encoded_metadata = compiled_submodule.__getstate__()[0][
+                            SERIALIZED_METADATA_IDX
+                        ]
+                        weight_name_map = TorchTensorRTModule.decode_metadata(
+                            encoded_metadata
+                        )["weight_name_map"]
+                    if not weight_name_map:
+                        use_weight_map_cache = False
+                        logger.warning(
+                            "This engine does not have a weight map cache. Rebuilding the weight map."
+                        )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                weight_name_map = None
+                if use_weight_map_cache:
+                    try:
+                        weight_name_map = compiled_submodule.weight_name_map
+                    except AttributeError:
+                        logger.warning(
+                            "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                        )
+                if not weight_name_map:
+                    use_weight_map_cache = False
+                    logger.warning(
+                        "This engine does not have a weight map cache. Rebuilding the weight map."
+                    )
             if isinstance(compiled_submodule, PythonTorchTensorRTModule):
                 engine = compiled_submodule.engine
             elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@
                 to_torch_device(settings.device),
                 name,
             )
-
-            _refit_single_trt_engine_with_gm(
-                new_gm=new_submodule,
-                old_engine=engine,
-                input_list=submodule_inputs,
-                settings=settings,
-            )
+            try:
+                _refit_single_trt_engine_with_gm(
+                    new_gm=new_submodule,
+                    old_engine=engine,
+                    input_list=submodule_inputs,
+                    settings=settings,
+                    weight_name_map=weight_name_map,
+                )
+            except AssertionError as e:
+                # If fast refit was attempted and failed, fall back to regular refit
+                logger.warning(e)
+                if use_weight_map_cache and weight_name_map:
+                    _refit_single_trt_engine_with_gm(
+                        new_gm=new_submodule,
+                        old_engine=engine,
+                        input_list=submodule_inputs,
+                        settings=settings,
+                        weight_name_map=None,
+                    )
 
         if isinstance(compiled_submodule, TorchTensorRTModule):
             serialized_engine = bytes(engine.serialize())
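Stripped of the mapping construction, the TensorRT refit protocol used by _refit_single_trt_engine_with_gm reduces to the calls below. A sketch, assuming an engine built with the REFIT flag and a mapping of engine weight names to (flat GPU tensor, trt.DataType) pairs as produced above:

    import tensorrt as trt
    import torch

    def refit_engine(engine, mapping, logger):
        refitter = trt.Refitter(engine, logger)
        for name in refitter.get_all_weights():
            tensor, trt_dtype = mapping[name]
            # Wrap the tensor's device memory without copying
            weights = trt.Weights(trt_dtype, tensor.data_ptr(), torch.numel(tensor))
            refitter.set_named_weights(name, weights, trt.TensorLocation.DEVICE)
        # Every engine weight must be supplied before refitting
        assert len(refitter.get_missing_weights()) == 0
        if not refitter.refit_cuda_engine():
            raise AssertionError("Refitting failed.")
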
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 703a650c99..9a3cace599 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -29,7 +30,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -47,6 +47,7 @@ class TRTInterpreterResult(NamedTuple):
     serialized_engine: bytes
     input_names: Sequence[str]
     output_names: Sequence[str]
+    weight_name_map: Optional[dict[Any, Any]]
 
 
 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
@@ -110,6 +111,7 @@ def __init__(
 
         # Mapping of constants to shapes and dtypes
         self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
+        self.weight_name_map: Optional[dict[str, Any]] = None
 
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()
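The weight_name_map threaded through TRTInterpreterResult maps each engine weight name to a [state_dict name(s), numpy dtype] pair, where batch-norm entries carry a list of the four tensors needed to derive scale/shift. The entries below are illustrative, not taken from a real engine:

    import numpy as np

    weight_name_map = {
        # Convolution kernel: a single state_dict tensor
        "[CONVOLUTION]-[...conv1/convolution] KERNEL": [
            "conv1.weight",
            np.dtype("float32"),
        ],
        # Batch norm: all four tensors are needed to compute scale and shift
        "[SCALE]-[...bn/batch_norm] SCALE": [
            ["bn.weight", "bn.bias", "bn.running_mean", "bn.running_var"],
            np.dtype("float32"),
        ],
    }
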
@@ -320,6 +322,141 @@ def _construct_trt_network_def(self) -> None:
             f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
         )
 
+    def _save_weight_mapping(self) -> None:
+        """
+        Construct the weight name mapping from engine weight name to state_dict weight name.
+        Cache the weight names for future refitting use cases.
+        Two-stage weight name tracing:
+        1. Name transformation from engine weight name to state_dict weight name
+        2. Value mapping that, for each weight in the INetworkDefinition, searches for an identical weight in the state_dict
+        """
+
+        def find_weight(
+            weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
+        ) -> str:
+            network_weight = np_map[weight_name]
+            for sd_w_name, sd_weight in sd.items():
+                if check_weight_equal(sd_weight, network_weight):
+                    return sd_w_name
+            return ""
+
+        def check_weight_equal(
+            sd_weight: torch.Tensor, network_weight: np.ndarray
+        ) -> bool:
+            sd_weight = sd_weight.reshape(-1).cpu().numpy()
+            return sd_weight.size == network_weight.size and np.allclose(
+                sd_weight, network_weight, 1e-1, 1e-1
+            )
+
+        MODULE_MAP = {
+            "SCALE": (
+                trt.IScaleLayer,
+                [
+                    (
+                        "scale",
+                        "SCALE",
+                        ("weight", "bias", "running_mean", "running_var"),
+                    ),
+                    (
+                        "shift",
+                        "SHIFT",
+                        ("weight", "bias", "running_mean", "running_var"),
+                    ),
+                ],
+            ),
+            "CONVOLUTION": (
+                trt.IConvolutionLayer,
+                [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")],
+            ),
+            "DECONVOLUTION": (
+                trt.IDeconvolutionLayer,
+                [("kernel", "KERNEL", "weight"), ("bias", "BIAS", "bias")],
+            ),
+            "CONSTANT": (
+                trt.IConstantLayer,
+                [("weights", "CONSTANT", ("weight", "bias"))],
+            ),
+        }
+        """
+        The structure of this map is:
+        {
+            layer_type: (
+                Corresponding ILayer type to cast,
+                [
+                    (
+                        ILayer weight attribute,
+                        Weight name postfix in TRT Engine,
+                        Weight name postfix in state_dict
+                    ),
+                    ...
+                ]
+            )
+        }
+        """
+        # Stage 1: Name mapping
+        sd = self.module.state_dict()
+        weight_name_map: dict[str, Any] = {}
+        np_map = {}
+        net = self.ctx.net
+        for i in range(net.num_layers):
+            layer = net[i]
+            layer_type: str = layer.type.name
+            if layer_type in MODULE_MAP:
+                layer.__class__ = MODULE_MAP[layer_type][0]
+                # Name mapping
+                for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
+                    weight = layer.__getattribute__(weight_type).copy()
+                    if weight.size == 0:
+                        continue
+                    engine_weight_name = f"{layer.name} {weight_name}"
+                    # Infer the corresponding weight name(s) in state_dict
+                    sd_weight_name_list = (
+                        layer.name.split("-")[-1]
+                        .replace("[", "")
+                        .replace("]", "")
+                        .split("/")
+                    )
+                    sd_weight_name: Any = ".".join(
+                        [i for i in sd_weight_name_list[:-1] if i]
+                    )
+                    suffix = sd_weight_name_list[-1]
+                    # Resolve the state_dict weight name(s) by layer type
+                    if layer_type == "CONSTANT":
+                        if "embedding" in suffix:
+                            sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                        elif "weight" in suffix or "mm_other" in suffix:
+                            # Linear layer weight
+                            sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                        else:
+                            sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
+                    elif layer_type == "SCALE":
+                        # Batch norm needs all weights to calculate scale and shift
+                        sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
+                    else:
+                        sd_weight_name = f"{sd_weight_name}.{torch_attr}"
+
+                    weight_name_map[engine_weight_name] = sd_weight_name
+                    np_map[engine_weight_name] = weight
+
+        # Stage 2: Value mapping
+        for engine_weight_name, sd_weight_name in weight_name_map.items():
+            if "SCALE" in engine_weight_name:
+                # Batch norm weights are fused into scale/shift, so there is no direct counterpart to check; skip
+                pass
+            elif sd_weight_name not in sd or not check_weight_equal(
+                sd[sd_weight_name], np_map[engine_weight_name]
+            ):
+                weight_name_map[engine_weight_name] = find_weight(
+                    engine_weight_name, np_map, sd
+                )
+
+            weight_name_map[engine_weight_name] = [
+                weight_name_map[engine_weight_name],
+                np_map[engine_weight_name].dtype,
+            ]
+
+        self.weight_name_map = weight_name_map
+
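Stage 1 recovers state_dict names purely from TRT layer-name string manipulation; a standalone sketch of that transformation (the layer name is illustrative):

    # Layer names end in "[<source module path>/<op>]" as emitted by get_node_name
    layer_name = "[CONVOLUTION]-[aten_ops.convolution.default]-[conv1/convolution]"

    parts = layer_name.split("-")[-1].replace("[", "").replace("]", "").split("/")
    prefix = ".".join(p for p in parts[:-1] if p)  # "conv1"
    suffix = parts[-1]                             # "convolution"

    # A CONVOLUTION layer maps its KERNEL weight to f"{prefix}.weight"
    assert f"{prefix}.weight" == "conv1.weight"
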
     def run(
         self,
         strict_type_constraints: bool = False,
@@ -335,6 +472,10 @@
             TRTInterpreterResult
         """
         self._construct_trt_network_def()
+
+        if self.compilation_settings.make_refitable:
+            self._save_weight_mapping()
+
         build_engine_start_time = datetime.now()
 
         builder_config = self._populate_trt_builder_config(
@@ -363,7 +504,9 @@
             engine_bytes.write(serialized_engine)
             engine_str = engine_bytes.getvalue()
 
-        return TRTInterpreterResult(engine_str, self._input_names, self._output_names)
+        return TRTInterpreterResult(
+            engine_str, self._input_names, self._output_names, self.weight_name_map
+        )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         self._cur_node_name = get_node_name(n)
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index c1663ca5cd..57fa1749bf 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -126,6 +126,28 @@ def convert_module(
         PythonTorchTensorRTModule or TorchTensorRTModule
     """
     interpreter_result = interpret_module_to_result(module, inputs, settings)
+    # Test fast refit:
+    from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
+    from torch_tensorrt.logging import TRT_LOGGER
+
+    runtime = trt.Runtime(TRT_LOGGER)
+    refit_test_engine = runtime.deserialize_cuda_engine(
+        interpreter_result.serialized_engine
+    )
+    weight_name_map: Any = None
+    # Do the test refit with the cached map if make_refitable is enabled
+    if settings.make_refitable:
+        weight_name_map = interpreter_result.weight_name_map
+        try:
+            _refit_single_trt_engine_with_gm(
+                new_gm=module,
+                old_engine=refit_test_engine,
+                input_list=inputs,
+                settings=settings,
+                weight_name_map=interpreter_result.weight_name_map,
+            )
+        except AssertionError:
+            # Discard the map so later refits rebuild it instead of reusing a bad cache
+            weight_name_map = None
+            logger.warning("Fast refit test failed. Removing the weight map cache.")
 
     rt_cls = PythonTorchTensorRTModule
 
@@ -149,4 +171,5 @@
         output_binding_names=list(interpreter_result.output_names),
         name=name,
         settings=settings,
+        weight_name_map=weight_name_map,
     )
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index f847091800..af0d6b720a 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,7 +1,6 @@
 import collections
 import functools
 import logging
-import re
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
@@ -34,10 +33,7 @@ def get_node_name(node: torch.fx.Node) -> str:
         mod_stack = stack_item.popitem() if stack_item else ""
         node_name = str(node)
         if mod_stack:
-            mod_name = str(mod_stack[0]).replace("___", "/")
-            # Clean up the module name
-            mod_name = re.sub("^.*__self", "", mod_name)
-            mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
+            mod_name = mod_stack[1][0]
             node_name = mod_name + "/" + node_name
         else:
             # Try an alternative way to get the module info
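The get_node_name rewrite above assumes each nn_module_stack value already carries the source module path as its first element, so the old regex cleanup is unnecessary. Roughly, under that assumption (the metadata literal is illustrative):

    # One entry of node.meta["nn_module_stack"], as produced by torch.export
    stack_item = {"L__self___conv1": ("conv1", "torch.nn.modules.conv.Conv2d")}

    mod_stack = stack_item.popitem()  # ("L__self___conv1", ("conv1", ...))
    mod_name = mod_stack[1][0]        # "conv1" -- the module path, read directly
    node_name = mod_name + "/" + "convolution"
    assert node_name == "conv1/convolution"
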
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 659f18af52..d5da83488a 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -4,6 +4,7 @@
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch.nn import Module
@@ -18,8 +19,6 @@
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -38,6 +37,7 @@ def __init__(
         *,
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),
+        weight_name_map: Any = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -102,6 +102,7 @@
         self.profiling_enabled = settings.debug if settings.debug is not None else False
         self.settings = settings
         self.engine = None
+        self.weight_name_map = weight_name_map
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 0ab0dd49ca..fe3974ff96 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -58,6 +58,7 @@ def __init__(
         *,
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),  # Assumes engine was built with default compilation settings if object not passed
+        weight_name_map: Optional[dict[Any, Any]] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines
@@ -107,6 +108,7 @@
         self.name = name
         self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
+        self.weight_name_map = weight_name_map
         self.serialized_engine = serialized_engine
         self.engine = None
 
@@ -130,6 +132,7 @@ def setup_engine(self) -> None:
             if self.settings.device is not None
             else Device._current_device()
         )
+        metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map}
         self.engine = torch.classes.tensorrt.Engine(
             [
                 torch.ops.tensorrt.ABI_VERSION(),
@@ -139,25 +142,28 @@
                 TorchTensorRTModule._pack_binding_names(self.input_binding_names),
                 TorchTensorRTModule._pack_binding_names(self.output_binding_names),
                 str(int(self.hardware_compatible)),
-                self.encode_metadata(self.settings),
+                self.encode_metadata(metadata),
             ]
         )
 
-    def encode_metadata(self, settings: Any) -> str:
-        settings = copy.deepcopy(settings)
-        settings.torch_executed_ops = {
-            f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops
+    def encode_metadata(self, metadata: Any) -> str:
+        metadata = copy.deepcopy(metadata)
+        metadata["settings"].torch_executed_ops = {
+            f"torch.ops.{op.__str__()}"
+            for op in metadata["settings"].torch_executed_ops
         }
-        dumped_settings = pickle.dumps(settings)
-        encoded_settings = base64.b64encode(dumped_settings).decode("utf-8")
-        return encoded_settings
+        dumped_metadata = pickle.dumps(metadata)
+        encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
+        return encoded_metadata
 
     @staticmethod
-    def decode_metadata(encoded_settings: bytes) -> Any:
-        dumped_settings = base64.b64decode(encoded_settings.encode("utf-8"))
-        settings = pickle.loads(dumped_settings)
-        settings.torch_executed_ops = {eval(op) for op in settings.torch_executed_ops}
-        return settings
+    def decode_metadata(encoded_metadata: str) -> Any:
+        dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
+        metadata = pickle.loads(dumped_metadata)
+        metadata["settings"].torch_executed_ops = {
+            eval(op) for op in metadata["settings"].torch_executed_ops
+        }
+        return metadata
 
     def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
         if self.engine is None and self.serialized_engine is not None:
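Since the serialized metadata is now a dict rather than a bare CompilationSettings, the encode/decode pair above is a pickle + base64 round trip over that dict. A self-contained sketch with stand-in values for the settings object and the cached map:

    import base64
    import pickle

    # Stand-ins for the real CompilationSettings and weight map
    metadata = {
        "settings": {"make_refitable": True},
        "weight_name_map": {"k": ["conv1.weight"]},
    }

    encoded = base64.b64encode(pickle.dumps(metadata)).decode("utf-8")
    decoded = pickle.loads(base64.b64decode(encoded.encode("utf-8")))
    assert decoded == metadata
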
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 82e655d736..c642ae0675 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -9,10 +9,9 @@
 import torch
 import torch.nn.functional as F
 import torch_tensorrt as torchtrt
+import torch_tensorrt as torch_trt
 import torchvision.models as models
 from torch import nn
-
-# from torch import nn
 from torch_tensorrt.dynamo import refit_module_weights
 from torch_tensorrt.dynamo._refit import (
     construct_refit_mapping,
@@ -29,6 +28,10 @@
 assertions = unittest.TestCase()
 
 
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
 @pytest.mark.unit
 def test_mapping():
 
@@ -81,8 +84,112 @@
     torch._dynamo.reset()
 
 
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_one_engine_with_weightmap():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=True,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_one_engine_no_map_with_weightmap():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    trt_gm._run_on_acc_0.weight_name_map = None
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=True,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
 @pytest.mark.unit
-def test_refit_one_engine():
+def test_refit_one_engine_with_wrong_weightmap():
 
     model = models.resnet18(pretrained=False).eval().to("cuda")
     model2 = models.resnet18(pretrained=True).eval().to("cuda")
@@ -104,11 +211,18 @@
         min_block_size=min_block_size,
         make_refitable=True,
     )
+    # Manually delete all batch norm entries from the weight map. This should
+    # make fast refit fail and trigger the fallback to regular refit.
+    trt_gm._run_on_acc_0.weight_name_map = {
+        k: v
+        for k, v in trt_gm._run_on_acc_0.weight_name_map.items()
+        if "[SCALE]" not in k
+    }
 
     new_trt_gm = refit_module_weights(
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
+        use_weight_map_cache=True,
     )
 
     # Check the output
@@ -125,8 +239,12 @@
     torch._dynamo.reset()
 
 
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
 @pytest.mark.unit
-def test_refit_one_engine_bert():
+def test_refit_one_engine_bert_with_weightmap():
     inputs = [
         torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
     ]
@@ -155,6 +273,7 @@
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
+        use_weight_map_cache=True,
     )
 
     # Check the output
@@ -175,8 +294,12 @@
     torch._dynamo.reset()
 
 
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
 @pytest.mark.unit
-def test_refit_one_engine_inline_runtime():
+def test_refit_one_engine_inline_runtime_with_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
     model = models.resnet18(pretrained=False).eval().to("cuda")
     model2 = models.resnet18(pretrained=True).eval().to("cuda")
@@ -204,6 +327,7 @@
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
+        use_weight_map_cache=True,
     )
 
     # Check the output
@@ -221,7 +345,7 @@
 
 
 @pytest.mark.unit
-def test_refit_one_engine_python_runtime():
+def test_refit_one_engine_python_runtime_with_weightmap():
 
     model = models.resnet18(pretrained=False).eval().to("cuda")
     model2 = models.resnet18(pretrained=True).eval().to("cuda")
@@ -248,6 +372,7 @@
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
+        use_weight_map_cache=True,
     )
 
     # Check the output
@@ -264,8 +389,282 @@
     torch._dynamo.reset()
 
 
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_multiple_engine_with_weightmap():
+
+    class net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+            self.bn = nn.BatchNorm2d(12)
+            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
+            self.fc1 = nn.Linear(12 * 56 * 56, 10)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = F.relu(x)
+            x = self.bn(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = self.conv2(x)
+            x = F.relu(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = torch.flatten(x, 1)
+            return self.fc1(x)
+
+    model = net().eval().to("cuda")
+    model2 = net().eval().to("cuda")
+
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    torch_executed_ops = {torch.ops.aten.convolution.default}
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+        torch_executed_ops=torch_executed_ops,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=True,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_one_engine_without_weightmap():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=False,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_one_engine_bert_without_weightmap():
+    inputs = [
+        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
+    ]
+    model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
+    model2 = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
+    nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight)
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=False,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        if not isinstance(expected_output, torch.Tensor) or not isinstance(
+            refitted_output, torch.Tensor
+        ):
+            continue
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
+@pytest.mark.unit
+def test_refit_one_engine_inline_runtime_without_weightmap():
+    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+    torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
+    trt_gm = torch.export.load(trt_ep_path)
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=False,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_one_engine_python_runtime_without_weightmap():
+
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = True
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=False,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+
+@unittest.skipIf(
+    not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
+    "Torch-TensorRT runtime is not available",
+)
 @pytest.mark.unit
-def test_refit_multiple_engine():
+def test_refit_multiple_engine_without_weightmap():
 
     class net(nn.Module):
         def __init__(self):
@@ -314,6 +713,7 @@ def forward(self, x):
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
         arg_inputs=inputs,
+        use_weight_map_cache=False,
     )
 
     # Check the output