feat: Implement FP32 accumulation for matmul #3110

Merged: 81 commits, merged Oct 11, 2024
Commits (81)
2ea181a
chore: add gpt2 example
peri044 Jun 13, 2024
37b65a5
chore: add llama2 example
peri044 Jun 13, 2024
bd12b12
Merge branch 'main' into llm_examples_main
peri044 Jun 13, 2024
4a9f73e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
0387d0b
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
6193939
chore: updates
peri044 Jun 14, 2024
9d3296e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
84fc49c
Merge branch 'main' into llm_examples_main
peri044 Jun 18, 2024
ff17d91
chore: rebase
peri044 Jun 18, 2024
8e6ba26
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Jun 24, 2024
67ec408
Merge branch 'main' into llm_examples_main
peri044 Jun 25, 2024
9af8e39
chore: remove aten.full decomposition
peri044 Jun 25, 2024
50d4096
chore: fix expand DS support
peri044 Jun 25, 2024
59febf5
chore: minor fix
peri044 Jun 26, 2024
c3e4382
chore: updates
peri044 Jun 26, 2024
0673db4
chore: add testcase
peri044 Jun 26, 2024
0b62f8f
Merge branch 'main' into full
peri044 Jun 26, 2024
54f6410
Merge branch 'full' into fix_expand_ds
peri044 Jun 26, 2024
ae3d6b2
Merge branch 'fix_expand_ds' into llm_examples_main
peri044 Jun 26, 2024
4464fd5
chore: updates
peri044 Jun 26, 2024
63b13cf
chore: updates
peri044 Jun 28, 2024
3d10b92
Merge branch 'main' into llm_examples_main
peri044 Jun 28, 2024
e97a94f
chore: updates
peri044 Jul 10, 2024
4f503a8
chore: updates
peri044 Jul 11, 2024
5ecf63e
chore: rebase
peri044 Jul 11, 2024
0d00d8c
chore: updates
peri044 Jul 11, 2024
8099003
chore: updates
peri044 Jul 11, 2024
457f706
chore: updates
peri044 Jul 11, 2024
ce3b2f8
chore: updates
peri044 Jul 11, 2024
d8acadc
chore: updates
peri044 Jul 12, 2024
262c87d
chore: updates
peri044 Jul 12, 2024
bb94dfd
chore: rebase
peri044 Jul 17, 2024
736b839
chore: updates
peri044 Jul 17, 2024
313380e
chore: bug fixes
peri044 Jul 18, 2024
1057d83
chore: updates
peri044 Jul 19, 2024
bfd0cf2
chore: fixes
peri044 Jul 20, 2024
17ddb31
chore: updates
peri044 Jul 20, 2024
88be4fa
chore: add torch compile gpt2 example
peri044 Jul 22, 2024
df825ab
chore: updates
peri044 Jul 22, 2024
ff07295
chore: add timing calculation
peri044 Jul 24, 2024
857b0aa
Merge branch 'main' into llm_examples_main
peri044 Jul 24, 2024
8fae56b
Merge branch 'main' into llm_examples_main
peri044 Jul 29, 2024
d483718
chore: rebase
peri044 Jul 31, 2024
397e4bc
Merge branch 'main' into llm_examples_main
peri044 Aug 5, 2024
6c9b9fe
chore: updates
peri044 Aug 5, 2024
6313b1c
chore: updates
peri044 Aug 9, 2024
d608cc5
chore: rebase
peri044 Aug 9, 2024
1327782
chore: rebase fixes
peri044 Aug 9, 2024
0980778
chore: updates
peri044 Aug 9, 2024
94b2ba1
chore: updates
peri044 Aug 9, 2024
2b1db29
chore: updates
peri044 Aug 9, 2024
9f606fc
chore: updates
peri044 Aug 9, 2024
0cf23be
Merge branch 'main' into llm_examples_main
peri044 Aug 14, 2024
3228c57
chore: Update perf tooling with support for HF models (#3034)
peri044 Aug 15, 2024
6786f0e
chore: updates
Aug 15, 2024
e4873d0
chore: updates
peri044 Aug 19, 2024
a725ce0
Merge branch 'main' into llm_examples_main
peri044 Aug 19, 2024
bb10de4
feat: lowering replace aten.full_like with aten.full
chohk88 Aug 12, 2024
1527aa0
chore: minor linting
chohk88 Aug 12, 2024
67e33c3
chore: updates
peri044 Aug 19, 2024
5627c1a
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Aug 19, 2024
7be8604
chore: updates
peri044 Aug 21, 2024
4d75a2e
Merge branch 'main' into llm_examples_main
peri044 Aug 21, 2024
0ab0dbf
feat: add fp32 accumulation option for matmul layer
peri044 Aug 21, 2024
3c815f8
chore: updates
Aug 28, 2024
5617c0a
chore: Bump TRT version to 10.3.0.26 (#3071)
zewenli98 Aug 24, 2024
213526e
chore: updates
peri044 Aug 30, 2024
c193593
chore : updates
peri044 Aug 30, 2024
0de0b16
chore: updates
peri044 Sep 24, 2024
a90191d
chore: rebase with main
peri044 Sep 24, 2024
71e33cb
chore: updates
peri044 Sep 26, 2024
4257b1e
chore: updates
peri044 Sep 30, 2024
619a39a
chore: updates
peri044 Oct 1, 2024
8c0b9c6
chore: trunc_fiv fix
peri044 Oct 7, 2024
b6261f9
chore: update result
peri044 Oct 7, 2024
ebdfe8f
fix: add model.half() for llama2
peri044 Oct 7, 2024
61ec948
chore: address review comments
peri044 Oct 8, 2024
dd27a54
chore: address review comments
peri044 Oct 8, 2024
b2e5244
chore: add docs
peri044 Oct 8, 2024
7ddd637
chore: updates
peri044 Oct 8, 2024
4529717
chore: sign bug fix
peri044 Oct 10, 2024
4 changes: 4 additions & 0 deletions docsrc/index.rst
@@ -37,6 +37,7 @@ User Guide
* :ref:`saving_models`
* :ref:`runtime`
* :ref:`using_dla`
* :ref:`mixed_precision`

.. toctree::
:caption: User Guide
@@ -48,6 +49,7 @@ User Guide
user_guide/saving_models
user_guide/runtime
user_guide/using_dla
user_guide/mixed_precision
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/vgg16_ptq
tutorials/_rendered_examples/dynamo/engine_caching_example
@@ -118,6 +120,8 @@ Tutorials
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2

Python API Documentation
------------------------
74 changes: 74 additions & 0 deletions docsrc/user_guide/mixed_precision.rst
@@ -0,0 +1,74 @@
.. _mixed_precision:

Compile Mixed Precision models with Torch-TensorRT
===================================================
.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
:members:
:undoc-members:
:show-inheritance:

Consider the following PyTorch model, which explicitly casts an intermediate layer to run in FP16.

.. code-block:: python

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10,10)
self.linear2 = torch.nn.Linear(10,30).half()
self.linear3 = torch.nn.Linear(30,40)

def forward(self, x):
x = self.linear1(x)
x = x.to(torch.float16)
x = self.linear2(x)
x = x.to(torch.float32)
x = self.linear3(x)
return x


If we compile the above model using Torch-TensorRT, the layer profiling logs indicate that all layers
run in FP32. This is because TensorRT picks the kernels for each layer that give the best performance.

.. code-block:: python

inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
mod = MyModule().eval().cuda()
ep = torch.export.export(mod, tuple(inputs))
with torch_tensorrt.logging.debug():
trt_gm = torch_tensorrt.dynamo.compile(ep,
inputs=inputs,
debug=True)

# Debug log info
# Layers:
# Name: __myl_MulSum_myl0_0, LayerType: kgen, Inputs: [ { Name: __mye116_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }], TacticName: __myl_MulSum_0xfa6c1858aea1b13b03f90165d7149ec6, StreamId: 0, Metadata:
# Name: __myl_AddResMulSum_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye131_dconst, Dimensions: [10,30], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }, { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_AddResMulSum_0xb3915d7ebfe48be45b6d49083479e12f, StreamId: 0, Metadata:
# Name: __myl_AddResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye146_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_AddResMulSumAdd_0xcdd0085ad25f5f45ac5fafb72acbffd6, StreamId: 0, Metadata:


To respect the types specified by the user in the model (e.g., in this case, running the ``linear2`` layer in FP16), users can enable
the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs:

.. note:: If you enable ``use_explicit_typing=True``, only ``torch.float32`` is supported in ``enabled_precisions``.

.. code-block:: python

inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
mod = MyModule().eval().cuda()
ep = torch.export.export(mod, tuple(inputs))
with torch_tensorrt.logging.debug():
trt_gm = torch_tensorrt.dynamo.compile(ep,
inputs=inputs,
use_explicit_typing=True,
debug=True)

# Debug log info
# Layers:
# Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
# Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
# Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

Now the ``linear2`` layer runs in FP16 as shown in the above logs.
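
The FP32 matmul accumulation option added in this PR composes with explicit typing. The following is a minimal sketch (not taken from the docs above) showing both flags together on a hypothetical FP16 model; the ``torch.nn.Sequential`` stand-in and tensor shapes are illustrative assumptions.

.. code-block:: python

    import torch
    import torch_tensorrt

    # Hypothetical FP16 model used only to illustrate the flags
    model = (
        torch.nn.Sequential(
            torch.nn.Linear(10, 30),
            torch.nn.ReLU(),
            torch.nn.Linear(30, 40),
        )
        .half()
        .eval()
        .cuda()
    )

    inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
    ep = torch.export.export(model, tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        enabled_precisions={torch.float32},  # only FP32 is allowed with explicit typing
        use_explicit_typing=True,            # respect the precisions set in the model
        use_fp32_acc=True,                   # accumulate matmuls in FP32
    )

With both flags enabled, TensorRT keeps the matmul inputs in FP16 while accumulating in FP32, which is closer to PyTorch's eager-mode numerics.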
31 changes: 24 additions & 7 deletions examples/dynamo/README.rst
@@ -1,19 +1,36 @@
.. _torch_compile:

Dynamo / ``torch.compile``
----------------------------
Torch-TensorRT Examples
====================================

Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
a number of ways you can leverage this backend to accelerate inference.
Please refer to the following examples which demonstrate the usage of different features of Torch-TensorRT. We also provide
examples of Torch-TensorRT compilation of select computer vision and language models.

* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
Dependencies
------------------------------------

Please install the following external dependencies (assuming you already have the correct `torch`, `torch_tensorrt` and `tensorrt` libraries installed; see `dependencies <https://github.com/pytorch/TensorRT?tab=readme-ov-file#dependencies>`_):

.. code-block:: sh

pip install -r requirements.txt


Compiler Features
------------------------------------
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT

Model Zoo
------------------------------------
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
* :ref:`torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
4 changes: 2 additions & 2 deletions examples/dynamo/requirements.txt
@@ -1,4 +1,4 @@
cupy==13.1.0
torch>=2.4.0.dev20240503+cu121
torch-tensorrt>=2.4.0.dev20240503+cu121
triton==2.3.0
diffusers==0.30.3
transformers==4.44.2
30 changes: 22 additions & 8 deletions examples/dynamo/torch_export_gpt2.py
@@ -25,12 +25,16 @@
# CPU is used here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
attn_implementation="eager",
).eval()
model = (
AutoModelForCausalLM.from_pretrained(
"gpt2",
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
attn_implementation="eager",
)
.eval()
.half()
)

# %%
# Tokenize a sample input prompt and get pytorch model outputs
@@ -48,6 +52,10 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the GPT2 model into an ExportedProgram which is input of TRT compilation
# To compile the model in FP16, we do the following:
# 1) Cast the model to FP16 via model.half()
# 2) Enable use_explicit_typing=True. Certain layers are explicitly cast to FP32 within the PyTorch model, and this flag respects that behavior during TRT compilation
# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
trt_model = torch_tensorrt.dynamo.compile(
gpt2_ep,
@@ -56,6 +64,8 @@
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_explicit_typing=True,
use_fp32_acc=True,
)

# Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -81,6 +91,10 @@
# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# Pytorch model generated text: What is parallel programming ?

# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# TensorRT model generated text: What is parallel programming ?

# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
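
The auto-regressive generation loop referenced above is provided by a helper that is not part of this diff. A minimal greedy-decoding sketch of what such a loop might look like is shown below; ``MAX_TOKENS`` and the assumption that the compiled model returns logits of shape ``[batch, seq_len, vocab]`` are illustrative, not taken from the example's helper.

.. code-block:: python

    import torch

    MAX_TOKENS = 64  # assumed generation budget
    seq = input_ids.clone()
    with torch.no_grad():
        for _ in range(MAX_TOKENS):
            logits = trt_model(seq)  # assumed to return [batch, seq_len, vocab] logits
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            seq = torch.cat([seq, next_token], dim=-1)
            if (next_token == tokenizer.eos_token_id).all():
                break

    print(tokenizer.decode(seq[0], skip_special_tokens=True))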
21 changes: 15 additions & 6 deletions examples/dynamo/torch_export_llama2.py
@@ -24,9 +24,13 @@
# CPU is used here so that GPU memory is reserved for TRT compilation.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
model = AutoModelForCausalLM.from_pretrained(
llama_path, use_cache=False, attn_implementation="eager"
).eval()
model = (
AutoModelForCausalLM.from_pretrained(
llama_path, use_cache=False, attn_implementation="eager"
)
.eval()
.half()
)

tokenizer = AutoTokenizer.from_pretrained(llama_path)

@@ -45,15 +49,20 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the llama2 model into an ExportedProgram which is input of TRT compilation
# To compile the model in FP16, we do the following:
# 1) Cast the model to FP16 via model.half()
# 2) Enable use_explicit_typing=True. Certain layers are explicitly cast to FP32 within the PyTorch model, and this flag respects that behavior during TRT compilation
# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
llama2_ep,
inputs=[input_ids],
enabled_precisions={torch.float32},
min_block_size=1,
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_explicit_typing=True,
use_fp32_acc=True,
)

# Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -85,6 +94,6 @@
# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation, which respects the precisions set in the PyTorch model. This is useful when users have mixed-precision graphs.
use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers, and TensorRT ensures the matmul accumulation happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -197,6 +201,20 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
):
raise AssertionError(
f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
)

if use_fp32_acc:
logger.debug(
"FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
)

# Aliasing inputs to arg_inputs for better understanding
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -232,7 +250,7 @@ def compile(
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
gm = post_lowering(gm)
gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
logger.debug("Lowered Input graph: " + str(gm.graph))

engine_cache = None
@@ -281,6 +299,8 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
}

settings = CompilationSettings(**compilation_options)
@@ -520,6 +540,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator: object = None,
allow_shape_tensors: bool = False,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -578,6 +600,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation, which respects the precisions set in the PyTorch model. This is useful when users have mixed-precision graphs.
use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers, and TensorRT ensures the matmul accumulation happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -651,6 +675,8 @@ def convert_exported_program_to_serialized_trt_engine(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
}

exported_program = pre_export_lowering(exported_program)
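
For context, ``use_fp32_acc`` is applied as a graph lowering step (note the ``post_lowering(gm, use_fp32_acc=use_fp32_acc)`` call above) that wraps matmuls in FP32 casts. Below is a hedged sketch of what such a pass could look like on an aten-level FX graph; the function name, the set of matmul targets, and the hard-coded FP16 output cast are assumptions, not the actual pass in this PR.

.. code-block:: python

    import torch
    from torch.fx import GraphModule

    # Illustrative set of matmul ops to rewrite (assumption)
    MATMUL_TARGETS = {torch.ops.aten.mm.default, torch.ops.aten.bmm.default}

    def insert_fp32_acc_casts(gm: GraphModule) -> GraphModule:
        """Wrap each matmul in FP32 casts so its accumulation happens in FP32."""
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target in MATMUL_TARGETS:
                # Cast both matmul operands to FP32 just before the matmul
                with gm.graph.inserting_before(node):
                    casts = [
                        gm.graph.call_function(
                            torch.ops.aten._to_copy.default,
                            (arg,),
                            {"dtype": torch.float32},
                        )
                        for arg in node.args
                    ]
                node.args = tuple(casts)
                # Cast the FP32 result back to FP16 for downstream consumers
                with gm.graph.inserting_after(node):
                    back = gm.graph.call_function(
                        torch.ops.aten._to_copy.default,
                        (node,),
                        {"dtype": torch.float16},
                    )
                node.replace_all_uses_with(back)
                back.args = (node,)  # restore the matmul as the cast's input
        gm.graph.lint()
        gm.recompile()
        return gm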
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
CUSTOM_ENGINE_CACHE = None
USE_EXPLICIT_TYPING = False
USE_FP32_ACC = False


def default_device() -> Device: