Skip to content

Commit

Permalink
refactor with new design
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Sep 20, 2024
1 parent 77b1cfc commit 99b77e2
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 104 deletions.
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator: object = None,
allow_shape_tensors: bool = False,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -586,6 +588,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs.
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
Expand Down Expand Up @@ -659,6 +663,8 @@ def convert_exported_program_to_serialized_trt_engine(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
}

exported_program = pre_export_lowering(exported_program)
Expand Down
68 changes: 53 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def run(
cached_data = self.engine_cache.check(hash_val)
if cached_data is not None: # hit the cache
(
unrefitted_serialized_engine,
serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
Expand Down Expand Up @@ -582,12 +582,38 @@ def run(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

# Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(
trt.SerializationFlag.EXCLUDE_WEIGHTS
)
serialized_engine = engine.serialize_with_config(
serialization_config
)

with io.BytesIO() as engine_bytes:
engine_bytes.write(unrefitted_serialized_engine)
unrefitted_engine_str = engine_bytes.getvalue()
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
unrefitted_engine_str,
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
Expand All @@ -609,32 +635,44 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

# if strip_engine_weights is true, the serialized engine need to be refitted before using
maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
assert maybe_unrefitted_serialized_engine
assert serialized_engine

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(
f"TRT Engine uses: {maybe_unrefitted_serialized_engine.nbytes} bytes of Memory"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self._save_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

# if strip_engine_weights is true, the weight-stripped engine will be saved in engine cache
if (
self.engine_cache is not None
and self.compilation_settings.cache_built_engines
):
assert (
self.compilation_settings.make_refittable
), "weight-stripped engines must be refittable, please set make_refittable=True"

# no matter what compilation_settings is, we cache the weight-stripped engine
if self.compilation_settings.strip_engine_weights:
weight_stripped_serialized_engine = serialized_engine
else:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
serialization_config = engine.create_serialization_config()
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
weight_stripped_serialized_engine = engine.serialize_with_config(
serialization_config
)

self.engine_cache.insert(
hash_val,
(
maybe_unrefitted_serialized_engine,
weight_stripped_serialized_engine,
self._input_names,
self._output_names,
self.input_specs,
Expand All @@ -644,11 +682,11 @@ def run(
)

with io.BytesIO() as engine_bytes:
engine_bytes.write(maybe_unrefitted_serialized_engine)
maybe_unrefitted_engine_str = engine_bytes.getvalue()
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
maybe_unrefitted_engine_str,
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,5 +165,4 @@ def convert_module(
name=name,
settings=settings,
weight_name_map=interpreter_result.weight_name_map,
graph_module=module,
)
51 changes: 1 addition & 50 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform, dtype
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER
from torch_tensorrt.runtime._utils import (
_is_switch_required,
Expand All @@ -39,7 +39,6 @@ def __init__(
name: str = "",
settings: CompilationSettings = CompilationSettings(),
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
Expand All @@ -53,7 +52,6 @@ def __init__(
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
Example:
Expand Down Expand Up @@ -108,7 +106,6 @@ def __init__(
self.settings = settings
self.engine = None
self.weight_name_map = weight_name_map
self.graph_module = graph_module # may be used to refit the weights
self.target_platform = Platform.current_platform()

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
Expand All @@ -124,52 +121,6 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
assert (
self.settings.make_refittable
), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
refittable_weights = refitter.get_all_weights()
torch_device = get_model_device(self.graph_module)

for layer_name in refittable_weights:
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)
from torch_tensorrt.dynamo._refit import (
construct_refit_mapping_from_weight_name_map,
)

mapping = construct_refit_mapping_from_weight_name_map(
self.weight_name_map, self.graph_module.state_dict()
)

for layer_name in refittable_weights:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use Numpy to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(
layer_name, trt_wt_tensor, trt_wt_location
)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"

# Refit the engine
if refitter.refit_cuda_engine():
logger.info("Engine refitted successfully!")
else:
logger.info("Engine refit failed!")

assert self.engine.num_io_tensors == (
len(self.input_names) + len(self.output_names)
)
Expand Down
3 changes: 0 additions & 3 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines
Expand All @@ -98,7 +97,6 @@ def __init__(
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
Example:
Expand Down Expand Up @@ -132,7 +130,6 @@ def __init__(
self.hardware_compatible = settings.hardware_compatible
self.settings = copy.deepcopy(settings)
self.weight_name_map = weight_name_map
self.graph_module = graph_module
self.serialized_engine = serialized_engine
self.engine = None

Expand Down
6 changes: 3 additions & 3 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_dynamo_compile_change_input_shape(self):
trt_gm = torch_trt.dynamo.compile(
torch.export.export(model, args=inputs),
inputs=inputs,
use_python_runtime=True,
use_python_runtime=False,
enabled_precisions={torch.float},
debug=False,
min_block_size=1,
Expand Down Expand Up @@ -387,7 +387,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"use_python_runtime": False,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
Expand Down Expand Up @@ -452,7 +452,7 @@ def test_torch_compile_with_custom_engine_cache(self):
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"use_python_runtime": False,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
Expand Down
Loading

0 comments on commit 99b77e2

Please sign in to comment.