diff --git a/py/torch_tensorrt/dynamo/_engine_caching.py b/py/torch_tensorrt/dynamo/_engine_caching.py
index ee5a6ec854..c8ff7aba50 100644
--- a/py/torch_tensorrt/dynamo/_engine_caching.py
+++ b/py/torch_tensorrt/dynamo/_engine_caching.py
@@ -48,7 +48,7 @@ def pack(
         serialized_engine: bytes,
         input_names: List[str],
         output_names: List[str],
-        weight_name_map: Optional[Dict[str, Any]],
+        weight_name_map: Optional[Dict[Any, Any]],
     ) -> bytes:
         """Pack serialized engine, input names, output names, and weight map into a single blob

@@ -56,7 +56,7 @@ def pack(
             serialized_engine (bytes): serialized TRT engine
             input_names (List[str]): input names of TRT engine
             output_names (List[str]): output names of TRT engine
-            weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
+            weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting

         Returns:
             bytes: packed blob
@@ -73,7 +73,7 @@ def pack(
     @staticmethod
     def unpack(
         packed_obj: bytes,
-    ) -> Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]:
+    ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
        """Unpack packed blob into serialized engine, input names, output names, and weight map

        Args:
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index c944e8a971..30ce62960e 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -501,7 +501,6 @@ def run(
                 _LOGGER.info(
                     "Found the cached engine that corresponds to this graph. It is directly loaded."
                 )
-                # TODO: refit the engine here or outside (within convert_module)?
                 return TRTInterpreterResult(
                     serialized_engine,
                     self._input_names,
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index e03c6cf832..5a8f8d62bf 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -107,6 +107,36 @@ def interpret_module_to_result(
         compilation_settings=settings,
     )
     interpreter_result = interpreter.run()
+
+    if settings.make_refitable:
+        # Run fast refit even if it's the first compilation.
+        # This is to ensure that the weight name map is correct for future refits.
+        # If the fast refit fails, remove the weight name map.
+        from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
+        from torch_tensorrt.logging import TRT_LOGGER
+
+        runtime = trt.Runtime(TRT_LOGGER)
+        refit_test_engine = runtime.deserialize_cuda_engine(
+            interpreter_result.serialized_engine
+        )
+        try:
+            _refit_single_trt_engine_with_gm(
+                new_gm=module,
+                old_engine=refit_test_engine,
+                input_list=inputs,
+                settings=settings,
+                weight_name_map=interpreter_result.weight_name_map,
+            )
+        except AssertionError:
+            # TRTInterpreterResult is a tuple, so we need to create a new one
+            interpreter_result = TRTInterpreterResult(
+                interpreter_result.serialized_engine,
+                interpreter_result.input_names,
+                interpreter_result.output_names,
+                None,
+            )
+            logger.warning("Fast refit test failed. Removing the weight map caching.")
+
     return interpreter_result
@@ -126,28 +156,6 @@ def convert_module(
         PythonTorchTensorRTModule or TorchTensorRTModule
     """
     interpreter_result = interpret_module_to_result(module, inputs, settings)
-    # Test fast refit:
-    from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
-    from torch_tensorrt.logging import TRT_LOGGER
-
-    runtime = trt.Runtime(TRT_LOGGER)
-    refit_test_engine = runtime.deserialize_cuda_engine(
-        interpreter_result.serialized_engine
-    )
-    weight_name_map: Any = None
-    # Do the test refit with cached map if make_refitable is enabled
-    if settings.make_refitable:
-        weight_name_map = interpreter_result.weight_name_map
-        try:
-            _refit_single_trt_engine_with_gm(
-                new_gm=module,
-                old_engine=refit_test_engine,
-                input_list=inputs,
-                settings=settings,
-                weight_name_map=interpreter_result.weight_name_map,
-            )
-        except AssertionError:
-            logger.warning("Fast refit test failed. Removing the weight map caching.")

     rt_cls = PythonTorchTensorRTModule
@@ -171,5 +179,5 @@ def convert_module(
         output_binding_names=list(interpreter_result.output_names),
         name=name,
         settings=settings,
-        weight_name_map=weight_name_map,
+        weight_name_map=interpreter_result.weight_name_map,
     )