move refit into interpret_module_to_result
zewenli98 committed Aug 24, 2024
1 parent 132547f commit a86260e
Showing 3 changed files with 34 additions and 27 deletions.
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/_engine_caching.py
@@ -48,15 +48,15 @@ def pack(
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
- weight_name_map: Optional[Dict[str, Any]],
+ weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
"""Pack serialized engine, input names, output names, and weight map into a single blob
Args:
serialized_engine (bytes): serialized TRT engine
input_names (List[str]): input names of TRT engine
output_names (List[str]): output names of TRT engine
- weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
+ weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
Returns:
bytes: packed blob
@@ -73,7 +73,7 @@ def pack(
@staticmethod
def unpack(
packed_obj: bytes,
- ) -> Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]:
+ ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
"""Unpack packed blob into serialized engine, input names, output names, and weight map
Args:
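For context, the pack/unpack pair round-trips a serialized engine and its metadata through a single cacheable blob. The sketch below mirrors the signatures shown in this diff; the pickle-based layout and the standalone function form are illustrative assumptions, not the library's actual serialization.

import pickle
from typing import Any, Dict, List, Optional, Tuple


def pack(
    serialized_engine: bytes,
    input_names: List[str],
    output_names: List[str],
    weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
    # Bundle the engine and its metadata into one cacheable blob.
    return pickle.dumps(
        {
            "serialized_engine": serialized_engine,
            "input_names": input_names,
            "output_names": output_names,
            "weight_name_map": weight_name_map,
        }
    )


def unpack(
    packed_obj: bytes,
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
    # Reverse of pack(): recover the engine and its metadata from the blob.
    obj = pickle.loads(packed_obj)
    return (
        obj["serialized_engine"],
        obj["input_names"],
        obj["output_names"],
        obj["weight_name_map"],
    )

The commit itself only widens the annotation from Dict[str, Any] to Dict[Any, Any], so weight name maps keyed by something other than plain strings pass type checking; pack/unpack behavior is unchanged at runtime.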
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -501,7 +501,6 @@ def run(
_LOGGER.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)
- # TODO: refit the engine here or outside (within convert_module)?
return TRTInterpreterResult(
serialized_engine,
self._input_names,
54 changes: 31 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -107,6 +107,36 @@ def interpret_module_to_result(
compilation_settings=settings,
)
interpreter_result = interpreter.run()

+ if settings.make_refitable:
+     # Run fast refit even if it's the first compilation.
+     # This is to ensure that the weight name map is correct for future refits.
+     # If the fast refit fails, remove the weight name map.
+     from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
+     from torch_tensorrt.logging import TRT_LOGGER
+
+     runtime = trt.Runtime(TRT_LOGGER)
+     refit_test_engine = runtime.deserialize_cuda_engine(
+         interpreter_result.serialized_engine
+     )
+     try:
+         _refit_single_trt_engine_with_gm(
+             new_gm=module,
+             old_engine=refit_test_engine,
+             input_list=inputs,
+             settings=settings,
+             weight_name_map=interpreter_result.weight_name_map,
+         )
+     except AssertionError:
+         # TRTInterpreterResult is a tuple, so we need to create a new one
+         interpreter_result = TRTInterpreterResult(
+             interpreter_result.serialized_engine,
+             interpreter_result.input_names,
+             interpreter_result.output_names,
+             None,
+         )
+         logger.warning("Fast refit test failed. Removing the weight map caching.")

return interpreter_result


@@ -126,28 +156,6 @@ def convert_module(
PythonTorchTensorRTModule or TorchTensorRTModule
"""
interpreter_result = interpret_module_to_result(module, inputs, settings)
- # Test fast refit:
- from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
- from torch_tensorrt.logging import TRT_LOGGER
-
- runtime = trt.Runtime(TRT_LOGGER)
- refit_test_engine = runtime.deserialize_cuda_engine(
- interpreter_result.serialized_engine
- )
- weight_name_map: Any = None
- # Do the test refit with cached map if make_refitable is enabled
- if settings.make_refitable:
- weight_name_map = interpreter_result.weight_name_map
- try:
- _refit_single_trt_engine_with_gm(
- new_gm=module,
- old_engine=refit_test_engine,
- input_list=inputs,
- settings=settings,
- weight_name_map=interpreter_result.weight_name_map,
- )
- except AssertionError:
- logger.warning("Fast refit test failed. Removing the weight map caching.")

rt_cls = PythonTorchTensorRTModule

@@ -171,5 +179,5 @@ def convert_module(
output_binding_names=list(interpreter_result.output_names),
name=name,
settings=settings,
- weight_name_map=weight_name_map,
+ weight_name_map=interpreter_result.weight_name_map,
)
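A side note on the fallback added in interpret_module_to_result above: TRTInterpreterResult is rebuilt field by field because it is an immutable tuple. Assuming it is a typing.NamedTuple with exactly the four fields used in this diff (an assumption; the diff only shows tuple-like construction), the same drop-the-weight-map fallback could be written with _replace, as in this self-contained sketch using a stand-in class:

from typing import Any, Dict, List, NamedTuple, Optional


class TRTInterpreterResult(NamedTuple):
    # Stand-in for the real class; field names are taken from the diff above.
    serialized_engine: bytes
    input_names: List[str]
    output_names: List[str]
    weight_name_map: Optional[Dict[Any, Any]]


result = TRTInterpreterResult(b"engine", ["x"], ["y"], {"conv.weight": 0})

# What the diff does: rebuild the tuple with the weight map dropped.
rebuilt = TRTInterpreterResult(
    result.serialized_engine, result.input_names, result.output_names, None
)

# Equivalent one-liner if the class really is a NamedTuple.
replaced = result._replace(weight_name_map=None)

assert rebuilt == replaced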
