Skip to content

Commit

Permalink
chore: rebase and update
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Oct 14, 2024
1 parent 26136f0 commit 9842b00
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
5 changes: 3 additions & 2 deletions examples/dynamo/weight_streaming_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
# Compiler option
# ----------------------------------
#
# enable_weight_streaming=True option and use_strong_types=True are required to build
# the engine with weight streaming feature. use_strong_types=True option creates a
# enable_weight_streaming=True option and use_explicit_typing=True are required to build
# the engine with weight streaming feature. use_explicit_typing=True option creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_ and only float32 precision is allowed in enabled_precisions option
#
trt_model = torch_tensorrt.dynamo.compile(
Expand All @@ -95,6 +95,7 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
enabled_precisions={torch.float32},
truncate_double=True,
device=DEVICE,
use_explicit_typing=True,
enable_weight_streaming=True,
)

Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def compile(
This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation.
)

if enable_weight_streaming and not use_explicit_typing:
raise AssertionError(
"When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True"
)
# Aliasing inputs to arg_inputs for better understanding
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
Expand Down Expand Up @@ -547,6 +551,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -607,6 +612,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool): Enable weight streaming.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
Expand Down Expand Up @@ -682,6 +688,7 @@ def convert_exported_program_to_serialized_trt_engine(
"timing_cache_path": timing_cache_path,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
"enable_weight_streaming": enable_weight_streaming,
}

exported_program = pre_export_lowering(exported_program)
Expand Down
6 changes: 0 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ def __init__(
)
flag |= STRONGLY_TYPED

if compilation_settings.enable_weight_streaming:
STRONGLY_TYPED = 1 << (int)(
trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED
)
flag |= STRONGLY_TYPED

self.ctx = ConversionContext(
self.builder.create_network(flag), compilation_settings
)
Expand Down
4 changes: 4 additions & 0 deletions tests/py/dynamo/runtime/test_004_weight_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_weight_streaming_default(self, _, use_python_runtime):
cache_built_engines=False,
reuse_cached_engines=False,
use_python_runtime=use_python_runtime,
use_explicit_typing=True,
enable_weight_streaming=True,
)
# Checking if default weight streaming budget(automatic) is applied when compiler option was provided
Expand Down Expand Up @@ -99,6 +100,7 @@ def test_weight_streaming_manual(self, _, use_python_runtime):
cache_built_engines=False,
reuse_cached_engines=False,
use_python_runtime=use_python_runtime,
use_explicit_typing=True,
enable_weight_streaming=True,
)
# Weight streaming budget is applied manually.
Expand Down Expand Up @@ -163,6 +165,7 @@ def test_weight_streaming_invalid_usage(self, _, use_python_runtime, multi_rt):
{"torch.ops.aten.convolution.default"} if multi_rt else {}
),
use_python_runtime=use_python_runtime,
use_explicit_typing=True,
enable_weight_streaming=True,
)

Expand Down Expand Up @@ -209,6 +212,7 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime):
reuse_cached_engines=False,
torch_executed_ops={"torch.ops.aten.convolution.default"},
use_python_runtime=use_python_runtime,
use_explicit_typing=True,
enable_weight_streaming=True,
)

Expand Down

0 comments on commit 9842b00

Please sign in to comment.