diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
index 84122e074b..a10c0e17ae 100644
--- a/examples/dynamo/mutable_torchtrt_module_example.py
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -34,7 +34,7 @@
     "make_refitable": True,
 }
 
-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
 mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
 # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
 mutable_module(*inputs)
@@ -45,7 +45,7 @@
 
 # %%
 # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
 mutable_module.load_state_dict(model2.state_dict())
 
diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index c47ed19ebb..c8cd5590d3 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -39,7 +39,7 @@
 # Compile the module for the first time and save it.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
 exp_program = torch.export.export(model, tuple(inputs))
 enabled_precisions = {torch.float}
 debug = False
@@ -68,7 +68,7 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 # Create and compile the updated model
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
 exp_program2 = torch.export.export(model2, tuple(inputs))
 
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index c97c3a6229..0e5e09de8a 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -56,9 +56,9 @@ def compile(
     disable_tf32: bool = _defaults.DISABLE_TF32,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
-    enabled_precisions: (
-        Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
-    ) = _defaults.ENABLED_PRECISIONS,
+    enabled_precisions: Union[
+        Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype], ...]
+    ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
     make_refitable: bool = _defaults.MAKE_REFITABLE,
     debug: bool = _defaults.DEBUG,
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 660cb8a875..4ce7d0b150 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -34,7 +34,7 @@
     TorchTensorRTModule,
 )
 from torch_tensorrt.dynamo.utils import (
-    check_output,
+    check_module_output,
     get_torch_inputs,
     set_log_level,
     to_torch_device,
@@ -115,19 +115,8 @@ def construct_refit_mapping_from_weight_name_map(
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
         trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
         torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
-            # Batch Norm Layer
-            params = {}
-            for w in sd_weight_name:
-                params[w.split(".")[-1]] = state_dict[w]
-            scale = 
params["weight"] / torch.sqrt(params["running_var"] + 1e-7) - shift = params["bias"] - params["running_mean"] * scale - # Set scale to scale or shift to shift - engine_weight_map[engine_weight_name] = eval( - engine_weight_name.split(" ")[-1].lower() - ) - elif sd_weight_name not in state_dict: + if sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue else: @@ -157,16 +146,25 @@ def _refit_single_trt_engine_with_gm( """ refitted = set() - + torch_device = list(new_gm.state_dict().values())[0].device.type refitter = trt.Refitter(old_engine, TRT_LOGGER) weight_list = refitter.get_all_weights() if weight_name_map: # Get the refitting mapping - trt_wt_location = trt.TensorLocation.DEVICE + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device == "cuda" + else trt.TensorLocation.HOST + ) mapping = construct_refit_mapping_from_weight_name_map( weight_name_map, new_gm.state_dict() ) + + # Debug Use + # correct = construct_refit_mapping(new_gm, input_list, settings) + # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct} + for layer_name in weight_list: if layer_name not in mapping: logger.warning(f"{layer_name} is not found in weight mapping.") @@ -235,7 +233,7 @@ def refit_module_weights( compiled_module = copy.deepcopy(compiled_module) elif inline_module: raise AssertionError( - "Exported program does not support modifying in place. Please set inplace to false and use the returned graph module." + "Exported program does not support modifying in place. Please set in_place to false and use the returned graph module." ) # Get the settings and check the setting to be uniform @@ -283,6 +281,7 @@ def refit_module_weights( arg_inputs = [arg_inputs] torch_inputs = get_torch_inputs(arg_inputs, device) + torch_kwarg_inputs: Any = {} if kwarg_inputs: torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) runtime = trt.Runtime(TRT_LOGGER) @@ -436,6 +435,7 @@ def refit_module_weights( settings=settings, weight_name_map=weight_name_map, ) + except AssertionError as e: # If fast_refit is used and failed, we fall back to regular refit logger.warning(e) @@ -463,7 +463,7 @@ def refit_module_weights( setattr(compiled_module, f"{name}_engine", refitted_engine) if verify_output and arg_inputs is not None: - if check_output( + if check_module_output( new_module=new_gm, refitted_module=compiled_module, arg_inputs=torch_inputs, @@ -471,6 +471,19 @@ def refit_module_weights( ): logger.info("Refitting Succeed!") else: + if weight_name_map: + logger.warning( + "Refitting with weight_name_map yielded incorrect result! The outputs do not match." + ) + return refit_module_weights( + compiled_module, + new_weight_module, + arg_inputs, + kwarg_inputs, + verify_output, + use_weight_map_cache=False, + in_place=in_place, + ) logger.error("Refitting Failed! The outputs do not match.") else: logger.info("Refitting Completed! 
Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 17437ceb6e..9fef61961b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -1,9 +1,21 @@ +import gc import io import logging import os import warnings from datetime import datetime -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Union, +) import numpy as np import torch @@ -26,7 +38,7 @@ get_node_name, get_trt_tensor, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -327,6 +339,39 @@ def _construct_trt_network_def(self) -> None: f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" ) + @staticmethod + def find_weight( + weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any] + ) -> str: + """ + We need to build map from engine weight name to state_dict weight name. + The purpose of this function is to find the corresponding weight name in module state_dict. + + weight_name: the target weight name we want to search for + np_map: the map from weight name to np values in INetworkDefinition + state_dict: state of the graph module + """ + network_weight = np_map[weight_name] + network_weight = torch.from_numpy(np_map[weight_name]).cuda() + for sd_w_name, sd_weight in state_dict.items(): + if TRTInterpreter.check_weight_equal(sd_weight, network_weight): + del state_dict[sd_w_name] + return sd_w_name + return "" + + @staticmethod + def check_weight_equal( + sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray] + ) -> Any: + if not isinstance(network_weight, torch.Tensor): + network_weight = torch.from_numpy(network_weight).cuda() + try: + return sd_weight.shape == network_weight.shape and torch.all( + torch.abs(sd_weight - network_weight) < 0.01 + ) + except Exception: + return torch.all(sd_weight == network_weight) + def _save_weight_mapping(self) -> None: """ Construct the weight name mapping from engine weight name to state_dict weight name. @@ -336,23 +381,6 @@ def _save_weight_mapping(self) -> None: 2. 
Value mapping that, for each weight in INetworkDefinition search for identical weight in state_dict
         """
 
-        def find_weight(
-            weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
-        ) -> str:
-            network_weight = np_map[weight_name]
-            for sd_w_name, sd_weight in sd.items():
-                if check_weight_equal(sd_weight, network_weight):
-                    return sd_w_name
-            return ""
-
-        def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: np.ndarray
-        ) -> Any:
-            sd_weight = sd_weight.reshape(-1).cpu().numpy()
-            return sd_weight.size == network_weight.size and np.allclose(
-                sd_weight, network_weight, 1e-1, 1e-1
-            )
-
         MODULE_MAP = {
             "SCALE": (
                 trt.IScaleLayer,
@@ -398,8 +426,19 @@ def check_weight_equal(
             )
         }
         """
+        _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         sd = self.module.state_dict()
+        torch_device = to_torch_device(self.compilation_settings.device)
+        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        if not gm_is_on_cuda:
+            # If the model originally resides on the CPU, move its weights to the GPU
+            sd = {
+                k: v.reshape(-1).to(torch_device)
+                for k, v in self.module.state_dict().items()
+            }
+        else:
+            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         net = self.ctx.net
@@ -448,10 +487,10 @@ def check_weight_equal(
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
-            elif sd_weight_name not in sd or not check_weight_equal(
+            elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
                 sd[sd_weight_name], np_map[engine_weight_name]
             ):
-                weight_name_map[engine_weight_name] = find_weight(
+                weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
                     engine_weight_name, np_map, sd
                 )
 
@@ -462,6 +501,10 @@ def check_weight_equal(
 
         self.weight_name_map = weight_name_map
 
+        del np_map, sd
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def run(
         self,
         strict_type_constraints: bool = False,
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index e03c6cf832..4cedcb80cb 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -130,14 +130,13 @@ def convert_module(
         from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
         from torch_tensorrt.logging import TRT_LOGGER
 
-        runtime = trt.Runtime(TRT_LOGGER)
-        refit_test_engine = runtime.deserialize_cuda_engine(
-            interpreter_result.serialized_engine
-        )
         weight_name_map: Any = None
         # Do the test refit with cached map if make_refitable is enabled
         if settings.make_refitable:
-            weight_name_map = interpreter_result.weight_name_map
+            runtime = trt.Runtime(TRT_LOGGER)
+            refit_test_engine = runtime.deserialize_cuda_engine(
+                interpreter_result.serialized_engine
+            )
             try:
                 _refit_single_trt_engine_with_gm(
                     new_gm=module,
@@ -146,9 +145,13 @@
                     settings=settings,
                     weight_name_map=interpreter_result.weight_name_map,
                 )
+                weight_name_map = interpreter_result.weight_name_map
             except AssertionError:
                 logger.warning("Fast refit test failed. 
Removing the weight map caching.") + del refit_test_engine + torch.cuda.empty_cache() + rt_cls = PythonTorchTensorRTModule if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime: diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index a437057d04..672a7e267d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -12,7 +12,11 @@ from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._refit import refit_module_weights from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import to_torch_device, to_torch_tensorrt_device +from torch_tensorrt.dynamo.utils import ( + check_output_equal, + to_torch_device, + to_torch_tensorrt_device, +) logger = logging.getLogger(__name__) @@ -228,7 +232,7 @@ def update_refit_condition(self) -> None: new_result = self.original_model(*args, **kwargs) self.original_model.cpu() torch.cuda.empty_cache() - if MutableTorchTensorRTModule.check_output_equal(result, new_result): + if check_output_equal(result, new_result): self.refit_state.set_state(RefitFlag.LIVE) return @@ -268,7 +272,12 @@ def refit_gm(self) -> None: ) ) self.gm = refit_module_weights( - self.gm, self.exp_program, use_weight_map_cache=True, in_place=True + self.gm, + self.exp_program, + self.arg_inputs, + self.kwarg_inputs, + use_weight_map_cache=True, + in_place=True, ) self.original_model.cpu() @@ -426,46 +435,6 @@ def __setattr__(self, name: str, value: Any) -> None: else: object.__setattr__(self, name, value) - @staticmethod - def check_output_equal( - output1: Any, - output2: Any, - ) -> bool: - # TODO: Move this to utils when all PRs are merged. This can be used by other functions. - if type(output1) != type(output2): - logger.warning( - "This module does not support using output verification to skip refit. Refit will be performed \ - whenever the state is UNKNOWN" - ) - return False - - if isinstance(output1, torch.Tensor): - if output1.shape != output2.shape: - return False - return torch.allclose(output1, output2, 1e-2, 1e-2) # type: ignore - - elif isinstance(output1, (tuple, list)): - if len(output1) != len(output2): - return False - for a, b in zip(output1, output2): - if not MutableTorchTensorRTModule.check_output_equal(a, b): - return False - return True - - elif isinstance(output1, dict): - if output1.keys() != output2.keys(): - return False - for a, b in zip(output1.values(), output2.values()): - if not MutableTorchTensorRTModule.check_output_equal(a, b): - return False - return True - - logger.warning( - "This module does not support using output verification to skip refit. 
Refit will be performed \
-            whenever the state is UNKNOWN"
-        )
-        return False
-
     @staticmethod
     def check_inputs_equal(
         input1: Any,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 1d7785717b..6d74ab61bf 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -22,6 +22,8 @@
 
 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
+RTOL = 5e-3
+ATOL = 5e-3
 
 
 class Frameworks(Enum):
@@ -394,7 +396,7 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
     return nested_decorator
 
 
-def check_output(
+def check_module_output(
    new_module: torch.fx.GraphModule,
    refitted_module: torch.fx.GraphModule,
    arg_inputs: Any,
@@ -403,14 +405,50 @@
     old_outputs, new_outputs = refitted_module(*arg_inputs), new_module(
         *arg_inputs, **kwarg_inputs
     )
-    for old_output, new_output in zip(old_outputs, new_outputs):
-        if isinstance(old_output, torch.Tensor) and isinstance(
-            new_outputs, torch.Tensor
-        ):
-            if not torch.allclose(old_output, new_output, 1e-2, 1e-2):
+    if type(old_outputs) != type(new_outputs):
+        logger.warning("The output types are different. Output check is skipped.")
+        return True
+    return check_output_equal(old_outputs, new_outputs)
+
+
+def check_output_equal(
+    output1: Any,
+    output2: Any,
+    rtol: float = RTOL,
+    atol: float = ATOL,
+) -> bool:
+
+    if type(output1) != type(output2):
+        logger.warning(
+            "The output types are different. check_output_equal will always return False."
+        )
+        return False
+
+    if isinstance(output1, torch.Tensor):
+        if output1.shape != output2.shape:
+            return False
+        return torch.allclose(output1, output2, rtol, atol)  # type: ignore
+
+    elif isinstance(output1, (tuple, list)):
+        if len(output1) != len(output2):
+            return False
+        for a, b in zip(output1, output2):
+            if not check_output_equal(a, b, rtol, atol):
+                return False
+        return True
+
+    elif isinstance(output1, dict):
+        if output1.keys() != output2.keys():
+            return False
+        for a, b in zip(output1.values(), output2.values()):
+            if not check_output_equal(a, b, rtol, atol):
                 return False
+        return True
 
-    return True
+    logger.warning(
+        "This output type is not supported for comparison. check_output_equal will always return False."
+ ) + return False def get_flat_args_with_check( diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 6cdee663e6..df1e4ee934 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -23,7 +23,7 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_torch_inputs +from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_torch_inputs _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def run_test( mod, inputs, interpreter, - rtol, - atol, + rtol=RTOL, + atol=ATOL, check_dtype=True, pyt_inputs=None, rt_cls=PythonTorchTensorRTModule, @@ -254,8 +254,8 @@ def run_test( self, mod, inputs, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, precision=dtype.f32, check_dtype=True, use_dynamo_tracer=False, @@ -374,8 +374,8 @@ def run_test_with_dynamic_shape( self, mod, input_specs, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 9cb63f4fdc..a29a8061db 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -5,6 +5,7 @@ from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -152,8 +153,8 @@ def forward(self, lhs_val, rhs_val): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 03bae9b68b..d935134ff2 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -3,6 +3,7 @@ from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -501,8 +502,8 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, ) diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 839474a0dd..3d0b41b791 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -4,6 +4,7 @@ from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -122,8 +123,8 @@ def forward(self, source_tensor, indice_tensor): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, ) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index c642ae0675..9782cd829c 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -35,8 +35,8 @@ @pytest.mark.unit def test_mapping(): - model = 
models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] trt_input = [ torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format) @@ -117,6 +117,7 @@ def test_refit_one_engine_with_weightmap(): new_weight_module=exp_program2, arg_inputs=inputs, use_weight_map_cache=True, + verify_output=True, ) # Check the output @@ -140,8 +141,8 @@ def test_refit_one_engine_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_no_map_with_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -191,8 +192,8 @@ def test_refit_one_engine_no_map_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_with_wrong_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -301,8 +302,8 @@ def test_refit_one_engine_bert_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_inline_runtime__with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -347,8 +348,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): @pytest.mark.unit def test_refit_one_engine_python_runtime_with_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -467,8 +468,8 @@ def forward(self, x): @pytest.mark.unit def test_refit_one_engine_without_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -571,8 +572,8 @@ def test_refit_one_engine_bert_without_weightmap(): @pytest.mark.unit def test_refit_one_engine_inline_runtime_without_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 
224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -617,8 +618,8 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): @pytest.mark.unit def test_refit_one_engine_python_runtime_without_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index 593a4322b7..86e7678a66 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -11,10 +11,31 @@ import torchvision.models as models from torch import nn from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import RefitFlag +from torch_tensorrt.dynamo.utils import check_output_equal assertions = unittest.TestCase() +@pytest.mark.unit +def test_check_output_equal(): + torch.manual_seed(0) + a = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + torch.manual_seed(0) + b = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + assertions.assertTrue( + check_output_equal(a, b), + msg=f"test_check_output_equal is not correct.", + ) + + @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -31,8 +52,8 @@ def test_resnet18(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -44,9 +65,7 @@ def test_resnet18(): # Check the output expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -73,7 +92,7 @@ def test_save(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -83,9 +102,7 @@ def test_save(): loaded_outputs, trt_gm_outputs = reload(*inputs), mutable_module(*inputs) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - loaded_outputs, trt_gm_outputs - ), + check_output_equal(loaded_outputs, trt_gm_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -109,7 +126,7 @@ def test_resnet18_modify_attribute(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -150,7 +167,7 @@ def 
test_resnet18_modify_attribute_no_refit(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -241,9 +258,7 @@ def forward(self, x, b=5, c=None, d=None): *args, **kwargs ) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -306,9 +321,7 @@ def set_weights(self): model.cuda() expected_outputs, refitted_outputs = model(*args), mutable_module(*args) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -371,9 +384,7 @@ def set_layer(self): model.cuda() # move offloaded model from cpu to cuda expected_outputs, refitted_outputs = model(*args), mutable_module(*args) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -443,31 +454,9 @@ def forward(self, x, b=5, c=None, d=None): *args, **kwargs ) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) # Clean up model env torch._dynamo.reset() - - -@pytest.mark.unit -def test_check_output_equal(): - torch.manual_seed(0) - a = { - "a": torch.rand(10, 30), - "b": [torch.rand(10, 30), torch.rand(5, 5)], - "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, - } - torch.manual_seed(0) - b = { - "a": torch.rand(10, 30), - "b": [torch.rand(10, 30), torch.rand(5, 5)], - "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, - } - assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal(a, b), - msg=f"test_check_output_equal is not correct.", - )
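For reference, a minimal usage sketch of the `check_output_equal` helper introduced in `py/torch_tensorrt/dynamo/utils.py` above (assumes this patch is applied; the tensors and dictionary keys are illustrative only):

    import torch
    from torch_tensorrt.dynamo.utils import check_output_equal

    # Tensors are compared with torch.allclose using the new RTOL/ATOL
    # defaults (5e-3); tuples, lists, and dicts are traversed recursively,
    # so nested module outputs can be verified directly.
    ref = {"logits": torch.ones(2, 3), "aux": [torch.zeros(4)]}
    out = {"logits": torch.ones(2, 3), "aux": [torch.zeros(4)]}
    assert check_output_equal(ref, out)  # same structure, same values

    out["logits"] += 1.0  # perturb well beyond the 5e-3 tolerances
    assert not check_output_equal(ref, out)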