feat: Implement FP32 accumulation for matmul #3110

Merged
merged 81 commits on Oct 11, 2024
Changes from 75 commits
Commits (81)
2ea181a
chore: add gpt2 example
peri044 Jun 13, 2024
37b65a5
chore: add llama2 example
peri044 Jun 13, 2024
bd12b12
Merge branch 'main' into llm_examples_main
peri044 Jun 13, 2024
4a9f73e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
0387d0b
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
6193939
chore: updates
peri044 Jun 14, 2024
9d3296e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
84fc49c
Merge branch 'main' into llm_examples_main
peri044 Jun 18, 2024
ff17d91
chore: rebase
peri044 Jun 18, 2024
8e6ba26
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Jun 24, 2024
67ec408
Merge branch 'main' into llm_examples_main
peri044 Jun 25, 2024
9af8e39
chore: remove aten.full decomposition
peri044 Jun 25, 2024
50d4096
chore: fix expand DS support
peri044 Jun 25, 2024
59febf5
chore: minor fix
peri044 Jun 26, 2024
c3e4382
chore: updates
peri044 Jun 26, 2024
0673db4
chore: add testcase
peri044 Jun 26, 2024
0b62f8f
Merge branch 'main' into full
peri044 Jun 26, 2024
54f6410
Merge branch 'full' into fix_expand_ds
peri044 Jun 26, 2024
ae3d6b2
Merge branch 'fix_expand_ds' into llm_examples_main
peri044 Jun 26, 2024
4464fd5
chore: updates
peri044 Jun 26, 2024
63b13cf
chore: updates
peri044 Jun 28, 2024
3d10b92
Merge branch 'main' into llm_examples_main
peri044 Jun 28, 2024
e97a94f
chore: updates
peri044 Jul 10, 2024
4f503a8
chore: updates
peri044 Jul 11, 2024
5ecf63e
chore: rebase
peri044 Jul 11, 2024
0d00d8c
chore: updates
peri044 Jul 11, 2024
8099003
chore: updates
peri044 Jul 11, 2024
457f706
chore: updates
peri044 Jul 11, 2024
ce3b2f8
chore: updates
peri044 Jul 11, 2024
d8acadc
chore: updates
peri044 Jul 12, 2024
262c87d
chore: updates
peri044 Jul 12, 2024
bb94dfd
chore: rebase
peri044 Jul 17, 2024
736b839
chore: updates
peri044 Jul 17, 2024
313380e
chore: bug fixes
peri044 Jul 18, 2024
1057d83
chore: updates
peri044 Jul 19, 2024
bfd0cf2
chore: fixes
peri044 Jul 20, 2024
17ddb31
chore: updates
peri044 Jul 20, 2024
88be4fa
chore: add torch compile gpt2 example
peri044 Jul 22, 2024
df825ab
chore: updates
peri044 Jul 22, 2024
ff07295
chore: add timing calculation
peri044 Jul 24, 2024
857b0aa
Merge branch 'main' into llm_examples_main
peri044 Jul 24, 2024
8fae56b
Merge branch 'main' into llm_examples_main
peri044 Jul 29, 2024
d483718
chore: rebase
peri044 Jul 31, 2024
397e4bc
Merge branch 'main' into llm_examples_main
peri044 Aug 5, 2024
6c9b9fe
chore: updates
peri044 Aug 5, 2024
6313b1c
chore: updates
peri044 Aug 9, 2024
d608cc5
chore: rebase
peri044 Aug 9, 2024
1327782
chore: rebase fixes
peri044 Aug 9, 2024
0980778
chore: updates
peri044 Aug 9, 2024
94b2ba1
chore: updates
peri044 Aug 9, 2024
2b1db29
chore: updates
peri044 Aug 9, 2024
9f606fc
chore: updates
peri044 Aug 9, 2024
0cf23be
Merge branch 'main' into llm_examples_main
peri044 Aug 14, 2024
3228c57
chore: Update perf tooling with support for HF models (#3034)
peri044 Aug 15, 2024
6786f0e
chore: updates
Aug 15, 2024
e4873d0
chore: updates
peri044 Aug 19, 2024
a725ce0
Merge branch 'main' into llm_examples_main
peri044 Aug 19, 2024
bb10de4
feat: lowering replace aten.full_like with aten.full
chohk88 Aug 12, 2024
1527aa0
chore: minor linting
chohk88 Aug 12, 2024
67e33c3
chore: updates
peri044 Aug 19, 2024
5627c1a
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Aug 19, 2024
7be8604
chore: updates
peri044 Aug 21, 2024
4d75a2e
Merge branch 'main' into llm_examples_main
peri044 Aug 21, 2024
0ab0dbf
feat: add fp32 accumulation option for matmul layer
peri044 Aug 21, 2024
3c815f8
chore: updates
Aug 28, 2024
5617c0a
chore: Bump TRT version to 10.3.0.26 (#3071)
zewenli98 Aug 24, 2024
213526e
chore: updates
peri044 Aug 30, 2024
c193593
chore : updates
peri044 Aug 30, 2024
0de0b16
chore: updates
peri044 Sep 24, 2024
a90191d
chore: rebase with main
peri044 Sep 24, 2024
71e33cb
chore: updates
peri044 Sep 26, 2024
4257b1e
chore: updates
peri044 Sep 30, 2024
619a39a
chore: updates
peri044 Oct 1, 2024
8c0b9c6
chore: trunc_fiv fix
peri044 Oct 7, 2024
b6261f9
chore: update result
peri044 Oct 7, 2024
ebdfe8f
fix: add model.half() for llama2
peri044 Oct 7, 2024
61ec948
chore: address review comments
peri044 Oct 8, 2024
dd27a54
chore: address review comments
peri044 Oct 8, 2024
b2e5244
chore: add docs
peri044 Oct 8, 2024
7ddd637
chore: updates
peri044 Oct 8, 2024
4529717
chore: sign bug fix
peri044 Oct 10, 2024
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -118,6 +118,8 @@ Tutorials
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2

Python API Documentation
------------------------
31 changes: 24 additions & 7 deletions examples/dynamo/README.rst
@@ -1,19 +1,36 @@
.. _torch_compile:

Dynamo / ``torch.compile``
----------------------------
Torch-TensorRT Examples
====================================

Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
a number of ways you can leverage this backend to accelerate inference.
Please refer to the following examples which demonstrate the usage of different features of Torch-TensorRT. We also provide
examples of Torch-TensorRT compilation of select computer vision and language models.

* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
Dependencies
------------------------------------

Please install the following external dependencies (assuming you already have the correct ``torch``, ``torch_tensorrt`` and ``tensorrt`` libraries installed; see `dependencies <https://github.com/pytorch/TensorRT?tab=readme-ov-file#dependencies>`_).

.. code-block:: sh

pip install -r requirements.txt


Compiler Features
------------------------------------
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT

Model Zoo
------------------------------------
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (`ir="dynamo"`)
* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (`ir="dynamo"`)
4 changes: 2 additions & 2 deletions examples/dynamo/requirements.txt
@@ -1,4 +1,4 @@
cupy==13.1.0
torch>=2.4.0.dev20240503+cu121
torch-tensorrt>=2.4.0.dev20240503+cu121
triton==2.3.0
diffusers==0.30.3
transformers==4.44.2
26 changes: 18 additions & 8 deletions examples/dynamo/torch_export_gpt2.py
@@ -25,12 +25,16 @@
# CPU is used here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
attn_implementation="eager",
).eval()
model = (
AutoModelForCausalLM.from_pretrained(
"gpt2",
pad_token_id=tokenizer.eos_token_id,
use_cache=False,
attn_implementation="eager",
)
.eval()
.half()
)

# %%
# Tokenize a sample input prompt and get pytorch model outputs
@@ -56,6 +60,8 @@
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_strong_types=True,
use_fp32_acc=True,
)

# Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -81,6 +87,10 @@
# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# Pytorch model generated text: What is parallel programming ?

# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# TensorRT model generated text: What is parallel programming ?

# The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
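The GPT-2 and Llama2 examples above both rely on an auto-regressive greedy-decoding loop that is elided from this diff. Below is a minimal sketch of what such a loop looks like, assuming a hypothetical helper name (greedy_generate) and Hugging Face-style model outputs; the helper the examples actually use may differ in details.

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=64, eos_token_id=None):
    # Greedy decoding: append the arg-max token one step at a time.
    seq = input_ids
    for _ in range(max_new_tokens):
        outputs = model(seq)
        # Hugging Face models return a ModelOutput; a compiled module may return raw logits.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        seq = torch.cat([seq, next_token], dim=-1)
        if eos_token_id is not None and (next_token == eos_token_id).all():
            break
    return seq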
7 changes: 4 additions & 3 deletions examples/dynamo/torch_export_llama2.py
@@ -50,10 +50,11 @@
llama2_ep,
inputs=[input_ids],
enabled_precisions={torch.float32},
min_block_size=1,
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_strong_types=True,
use_fp32_acc=True,
)

# Auto-regressive generation loop for greedy decoding using TensorRT model
@@ -85,6 +86,6 @@
# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_strong_types: bool = _defaults.USE_STRONG_TYPES,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_strong_types (bool): Enable strong typing in TensorRT compilation
use_fp32_acc (bool): Inserts casts to FP32 around matmul layers so that TensorRT accumulates the matmul in FP32. Use this only when FP16 precision is configured in enabled_precisions.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -197,6 +201,20 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if use_strong_types:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
):
raise AssertionError(
f"When use_strong_types is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
)

if use_fp32_acc:
logger.debug(
"FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
)

# Aliasing inputs to arg_inputs for better understanding
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -232,7 +250,7 @@ def compile(
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
gm = post_lowering(gm)
gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
logger.debug("Lowered Input graph: " + str(gm.graph))

engine_cache = None
@@ -281,6 +299,8 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"use_strong_types": use_strong_types,
"use_fp32_acc": use_fp32_acc,
}

settings = CompilationSettings(**compilation_options)
@@ -520,6 +540,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator: object = None,
allow_shape_tensors: bool = False,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
use_strong_types: bool = _defaults.USE_STRONG_TYPES,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -578,6 +600,8 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
use_strong_types (bool): Enable strong typing in TensorRT compilation
use_fp32_acc (bool): Inserts casts to FP32 around matmul layers so that TensorRT accumulates the matmul in FP32. Use this only when FP16 precision is configured in enabled_precisions.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
Expand Down Expand Up @@ -651,6 +675,8 @@ def convert_exported_program_to_serialized_trt_engine(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
"use_strong_types": use_strong_types,
"use_fp32_acc": use_fp32_acc,
}

exported_program = pre_export_lowering(exported_program)
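As a reading aid for the use_fp32_acc docstring above: the flag asks post_lowering to wrap matmul operations with casts. The snippet below is only a conceptual, eager-mode sketch of the resulting pattern (the function name matmul_fp32_acc is hypothetical and not part of the PR); the real pass rewrites aten matmul nodes in the FX graph.

import torch

def matmul_fp32_acc(a_fp16: torch.Tensor, b_fp16: torch.Tensor) -> torch.Tensor:
    # Cast the FP16 operands up to FP32 so the matmul accumulates in FP32 ...
    out = torch.matmul(a_fp16.to(torch.float32), b_fp16.to(torch.float32))
    # ... then cast the result back to FP16 so downstream layers keep running in FP16.
    return out.to(torch.float16)

As the logger message above describes, TensorRT can then execute the matmul layers in FP16 while accumulating in FP32.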
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
CUSTOM_ENGINE_CACHE = None
USE_STRONG_TYPES = False
USE_FP32_ACC = False


def default_device() -> Device:
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -30,7 +30,9 @@
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
USE_PYTHON_RUNTIME,
USE_STRONG_TYPES,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
default_device,
@@ -78,6 +80,8 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_types (bool): Enable strong typing in TensorRT compilation
use_fp32_acc (bool): Inserts casts to FP32 around matmul layers so that TensorRT accumulates the matmul in FP32. Use this only when FP16 precision is configured in enabled_precisions.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
use_strong_types: bool = USE_STRONG_TYPES
use_fp32_acc: bool = USE_FP32_ACC


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -100,7 +100,7 @@ def _pretraced_backend(

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm)
gm = post_lowering(gm, use_fp32_acc=settings.use_fp32_acc)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -80,10 +80,11 @@ def __init__(
self.builder = trt.Builder(self.logger)

flag = 0

# It is deprecated to not use this flag
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
flag |= EXPLICIT_BATCH
if compilation_settings.use_strong_types:
STRONGLY_TYPED = 1 << (int)(
trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED
)
flag |= STRONGLY_TYPED

self.ctx = ConversionContext(
self.builder.create_network(flag), compilation_settings
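For context on the interpreter change above, here is a minimal standalone sketch of how TensorRT network-definition creation flags compose as a bitmask (TensorRT 10 Python API); the variable names are illustrative and not taken from the PR.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

flag = 0
use_strong_types = True  # mirrors compilation_settings.use_strong_types
if use_strong_types:
    # In a strongly typed network, layer precisions follow the tensor dtypes in the
    # graph rather than builder precision flags such as FP16/INT8.
    flag |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)

network = builder.create_network(flag)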
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
import logging
from typing import Any, List, Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -13,13 +13,11 @@
broadcast_to_same_shape,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import (
broadcast,
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor


def get_python_op_from_trt_elementwise_op(
@@ -152,7 +150,7 @@ def convert_binary_elementwise(

if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
lhs_val, rhs_val = broadcast(
ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
)
else:
lhs_val, rhs_val = broadcast_to_same_shape(
9 changes: 4 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -18,14 +18,11 @@
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import broadcast
from torch_tensorrt.fx.types import TRTTensor

import tensorrt as trt


def trunc_div(
ctx: ConversionContext,
@@ -69,10 +66,12 @@ def trunc_div(
prod_output,
)

# TODO: This casting causes output divergence for llama2 in FP16.
# @apbose to investigate why this is needed and suggest alternatives.
# cast the sign_output back to int32 for trunc div
# This is required for scatter_reduce_.two(reduce='mean' where trunc_div casts it to float32 and TRTInterpreter expects int32)
if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
# if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
# sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)

# Convert constant input into ITensor for UnaryOperation
if not isinstance(input, trt.tensorrt.ITensor):
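To clarify what the trunc_div converter above has to reproduce, here is the reference behaviour in eager PyTorch, a sketch for illustration only rather than the TensorRT implementation:

import torch

def trunc_div_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Truncated division keeps the integer part of x / y, rounding toward zero,
    # i.e. sign(x / y) * floor(|x| / |y|); this is why the converter builds
    # sign and prod_output intermediates.
    return torch.div(x, y, rounding_mode="trunc")

x = torch.tensor([7.0, -7.0, 7.0, -7.0])
y = torch.tensor([2.0, 2.0, -2.0, -2.0])
print(trunc_div_reference(x, y))  # tensor([ 3., -3., -3.,  3.])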
13 changes: 9 additions & 4 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -296,16 +296,21 @@ class ReduceOperation(Enum):
AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))

def __new__(cls, description, func):
def __new__(cls, description: Any, func: Any) -> Any:
obj = object.__new__(cls)
obj._value_ = auto()
obj.description = description
obj.func = func
return obj

def reduce_operation_with_scatter(
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
):
self,
operation_lhs: Any,
initial_tensor: torch.Tensor,
dim: int,
index_tensor: torch.Tensor,
src_tensor: torch.Tensor,
) -> Any:
scatter_tensor = None
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
scatter_tensor = torch.zeros_like(initial_tensor)
@@ -341,7 +346,7 @@ def scatter_reduce_decomposition(
scatter_count_tensor = torch.zeros_like(input_tensor)
src_shape = list(src_tensor.shape)
src_dim = src_shape[dim]
if include_self == False:
if not include_self:
raise AssertionError("include_self False for scatter reduce not yet supported")
for i in range(0, src_dim):
src_slice = torch.select(src_tensor, dim, i)
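For reference alongside the decomposition above, a small eager-mode example of the operation being decomposed (scatter_reduce with reduce="mean" and include_self=True, the case this lowering supports); the tensors are illustrative only.

import torch

input_tensor = torch.zeros(3)
index = torch.tensor([0, 0, 2])
src = torch.tensor([1.0, 3.0, 5.0])

# Each destination slot averages its own value with every src element routed to it:
# slot 0 <- mean(0, 1, 3), slot 1 is untouched, slot 2 <- mean(0, 5)
out = input_tensor.scatter_reduce(0, index, src, reduce="mean", include_self=True)
print(out)  # tensor([1.3333, 0.0000, 2.5000])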