support weight-stripped engine and REFIT_IDENTICAL flag
zewenli98 committed Sep 19, 2024
1 parent bc93437 commit 77b1cfc
Showing 9 changed files with 252 additions and 45 deletions.
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
refit_identical_engine_weights (bool): Build the engine so that it can only be refit with weights identical to those used at build time (TensorRT's REFIT_IDENTICAL flag). This lets the builder apply optimizations that a fully refittable engine cannot, which is useful when the same weights will be supplied again at refit time.
strip_engine_weights (bool): Strip the weights from the serialized engine (TensorRT's STRIP_PLAN flag). A stripped engine must be refit before it can run, which is useful when engines are cached or shipped without their weights.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -281,6 +285,8 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
}

settings = CompilationSettings(**compilation_options)
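For reference, a minimal usage sketch of the two new options added to `compile`. The model, example input, and flag combination are illustrative assumptions, not part of this commit:

import torch
import torch_tensorrt

# Hypothetical model and example input; any exportable model works the same way.
model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU()).eval().cuda()
inputs = (torch.randn(8, 128).cuda(),)
exp_program = torch.export.export(model, inputs)

# Build a weight-stripped engine. Stripped engines must be refittable, so
# make_refittable=True is required; refit_identical_engine_weights=True tells
# TensorRT the refit weights will match the build-time weights, allowing
# additional builder optimizations.
trt_module = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    make_refittable=True,
    strip_engine_weights=True,
    refit_identical_engine_weights=True,
)

With strip_engine_weights=True, the runtime module refits the engine from the attached graph module when the engine is set up (see the _PythonTorchTensorRTModule changes below).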
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
CUSTOM_ENGINE_CACHE = None
REFIT_IDENTICAL_ENGINE_WEIGHTS = False
STRIP_ENGINE_WEIGHTS = False


def default_device() -> Device:
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -24,9 +24,11 @@
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
REFIT_IDENTICAL_ENGINE_WEIGHTS,
REQUIRE_FULL_COMPILATION,
REUSE_CACHED_ENGINES,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
@@ -78,6 +80,8 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
refit_identical_engine_weights (bool): Whether to build the engine so it can only be refit with weights identical to the build-time weights (REFIT_IDENTICAL)
strip_engine_weights (bool): Whether to strip the weights from the serialized engine (STRIP_PLAN)
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS
strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -124,6 +130,8 @@ class CompilationSettings:
"make_refittable",
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights",
)


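Because both new fields are listed in `_SETTINGS_TO_BE_ENGINE_INVARIANT`, changing either one should yield a distinct engine-cache entry. A minimal sketch of that idea follows; the fingerprint helper is illustrative and is not the library's hashing code:

import hashlib

from torch_tensorrt.dynamo._settings import CompilationSettings

ENGINE_INVARIANT_FIELDS = (
    "make_refittable",
    "strip_engine_weights",
    "refit_identical_engine_weights",
)

def settings_fingerprint(settings: CompilationSettings) -> str:
    # Hash only the engine-invariant fields so that flipping any of them
    # produces a different cache key (illustrative, not the real implementation).
    values = tuple((name, getattr(settings, name)) for name in ENGINE_INVARIANT_FIELDS)
    return hashlib.sha256(repr(values).encode()).hexdigest()

stripped = CompilationSettings(make_refittable=True, strip_engine_weights=True)
default = CompilationSettings()
assert settings_fingerprint(stripped) != settings_fingerprint(default)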
62 changes: 30 additions & 32 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
)

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -283,7 +283,16 @@ def _populate_trt_builder_config(
builder_config.clear_flag(trt.BuilderFlag.TF32)

if self.compilation_settings.make_refittable:
builder_config.set_flag(trt.BuilderFlag.REFIT)
if version.parse(trt.__version__) >= version.parse("10.0"):
if self.compilation_settings.refit_identical_engine_weights:
builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
else:
builder_config.set_flag(trt.BuilderFlag.REFIT)
else:
builder_config.set_flag(trt.BuilderFlag.REFIT)

if self.compilation_settings.strip_engine_weights:
builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)

if strict_type_constraints:
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
@@ -542,7 +551,7 @@ def run(
cached_data = self.engine_cache.check(hash_val)
if cached_data is not None: # hit the cache
(
serialized_engine,
unrefitted_serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
@@ -573,31 +582,12 @@ def run(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

# TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
# We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future.
_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=None,
)

serialized_engine = engine.serialize()

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
engine_bytes.write(unrefitted_serialized_engine)
unrefitted_engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str,
unrefitted_engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
@@ -619,27 +609,32 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

serialized_engine = self.builder.build_serialized_network(
# if strip_engine_weights is true, the serialized engine needs to be refitted before use
maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
assert serialized_engine
assert maybe_unrefitted_serialized_engine

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
_LOGGER.info(
f"TRT Engine uses: {maybe_unrefitted_serialized_engine.nbytes} bytes of Memory"
)

self._save_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

# if strip_engine_weights is true, the weight-stripped engine will be saved in the engine cache
if (
self.engine_cache is not None
and self.compilation_settings.cache_built_engines
):
self.engine_cache.insert(
hash_val,
(
serialized_engine,
maybe_unrefitted_serialized_engine,
self._input_names,
self._output_names,
self.input_specs,
Expand All @@ -649,11 +644,14 @@ def run(
)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
engine_bytes.write(maybe_unrefitted_serialized_engine)
maybe_unrefitted_engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str, self._input_names, self._output_names, self.weight_name_map
maybe_unrefitted_engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
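The refit/strip flag selection in `_populate_trt_builder_config` above distills to the following standalone helper against the raw TensorRT API; this is a sketch mirroring the version gate in the diff, not additional library code:

import tensorrt as trt
from packaging import version

def set_refit_flags(
    builder_config: trt.IBuilderConfig,
    refit_identical_engine_weights: bool,
    strip_engine_weights: bool,
) -> None:
    # REFIT_IDENTICAL is only available in TensorRT >= 10.0.
    if (
        version.parse(trt.__version__) >= version.parse("10.0")
        and refit_identical_engine_weights
    ):
        # The engine may only be refit with weights identical to the
        # build-time weights, letting the builder optimize more aggressively.
        builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)

    if strip_engine_weights:
        # Serialize the plan without weights; it must be refit before use.
        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)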
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
import logging
from typing import Any, List, Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


@@ -166,4 +165,5 @@ def convert_module(
name=name,
settings=settings,
weight_name_map=interpreter_result.weight_name_map,
graph_module=module,
)
57 changes: 53 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -5,22 +5,21 @@
from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform, dtype
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device
from torch_tensorrt.logging import TRT_LOGGER
from torch_tensorrt.runtime._utils import (
_is_switch_required,
_select_rt_device,
multi_gpu_device_check,
)

import tensorrt as trt

logger = logging.getLogger(__name__)


@@ -39,7 +38,8 @@ def __init__(
*,
name: str = "",
settings: CompilationSettings = CompilationSettings(),
weight_name_map: Any = None,
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: Optional[torch.fx.GraphModule] = None,
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -52,6 +52,8 @@ def __init__(
Keyword Arguments:
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
Example:
@@ -106,6 +108,7 @@ def __init__(
self.settings = settings
self.engine = None
self.weight_name_map = weight_name_map
self.graph_module = graph_module # may be used to refit the weights
self.target_platform = Platform.current_platform()

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
assert (
self.settings.make_refittable
), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
refittable_weights = refitter.get_all_weights()
torch_device = get_model_device(self.graph_module)

trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)

from torch_tensorrt.dynamo._refit import (
construct_refit_mapping_from_weight_name_map,
)

mapping = construct_refit_mapping_from_weight_name_map(
self.weight_name_map, self.graph_module.state_dict()
)

for layer_name in refittable_weights:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use Numpy to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(
layer_name, trt_wt_tensor, trt_wt_location
)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"

# Refit the engine
if refitter.refit_cuda_engine():
logger.info("Engine refitted successfully!")
else:
logger.error("Engine refit failed!")

assert self.engine.num_io_tensors == (
len(self.input_names) + len(self.output_names)
)
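For context, the refit step that `setup_engine` now performs on a weight-stripped engine boils down to the following standalone sketch. The name-to-tensor mapping and the float32-only dtype handling are illustrative simplifications; the module above derives the mapping from `weight_name_map` and the attached `graph_module`:

import tensorrt as trt
import torch

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def refit_stripped_engine(engine: trt.ICudaEngine, weights: dict) -> bool:
    # weights: mapping of refittable weight name -> torch.Tensor (illustrative).
    refitter = trt.Refitter(engine, TRT_LOGGER)
    for name in refitter.get_all_weights():
        tensor = weights.get(name)
        if tensor is None:
            continue  # anything skipped shows up in get_missing_weights()
        location = (
            trt.TensorLocation.DEVICE if tensor.is_cuda else trt.TensorLocation.HOST
        )
        # Assumes float32 weights for simplicity; a real mapping also carries dtypes.
        trt_weights = trt.Weights(trt.DataType.FLOAT, tensor.data_ptr(), torch.numel(tensor))
        refitter.set_named_weights(name, trt_weights, location)
    assert len(refitter.get_missing_weights()) == 0, "incomplete weight mapping"
    return refitter.refit_cuda_engine()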
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -79,6 +79,7 @@ def __init__(
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: Optional[torch.fx.GraphModule] = None,
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines
@@ -96,6 +97,8 @@
Keyword Arguments:
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
graph_module (torch.fx.GraphModule): GraphModule used to refit the weights
Example:
@@ -129,6 +132,7 @@ def __init__(
self.hardware_compatible = settings.hardware_compatible
self.settings = copy.deepcopy(settings)
self.weight_name_map = weight_name_map
self.graph_module = graph_module
self.serialized_engine = serialized_engine
self.engine = None
