From d75f588e49ca241cf44bbdb8aad0ef3d9578da75 Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Tue, 3 Sep 2024 22:45:55 +0800
Subject: [PATCH] feat: Support `aten.gelu` dynamo converter (#3134)

---
 .../dynamo/conversion/aten_ops_converters.py | 20 +++++++-
 .../dynamo/conversion/impl/activation/ops.py | 25 +++++++++-
 .../dynamo/lowering/_decomposition_groups.py |  1 -
 tests/py/dynamo/conversion/test_gelu_aten.py | 46 +++++++++++++------
 4 files changed, 75 insertions(+), 17 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1d734bca03..92dfddc44f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -19,7 +19,7 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -548,6 +548,24 @@ def aten_ops_hard_sigmoid(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.gelu.default, supports_dynamic_shapes=True)
+def aten_ops_gelu(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.gelu(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs.get("approximate", "none"),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
index f578351ef2..a563118526 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -7,7 +7,7 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
 
 
 def relu(
@@ -327,3 +327,26 @@ def thresholded_relu_fn(x: float) -> float:
         alpha=alpha,
         dyn_range_fn=thresholded_relu_dyn_range_fn,
     )
+
+
+def gelu(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    approximate: str,
+) -> TRTTensor:
+    if approximate == "none":
+        operation_type = trt.ActivationType.GELU_ERF
+    elif approximate == "tanh":
+        operation_type = trt.ActivationType.GELU_TANH
+
+    return convert_activation(
+        ctx,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+    )
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index ae3e7e1ffa..a84a550a1e 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -42,7 +42,6 @@
     aten.fill,
     aten.frac,
     aten._fused_moving_avg_obs_fq_helper,
-    aten.gelu,
     aten.gelu_backward,
     aten.glu_backward,
     aten.hardshrink,
diff --git a/tests/py/dynamo/conversion/test_gelu_aten.py b/tests/py/dynamo/conversion/test_gelu_aten.py
index df0a0eca5f..dac33a9ae6 100644
--- a/tests/py/dynamo/conversion/test_gelu_aten.py
+++ b/tests/py/dynamo/conversion/test_gelu_aten.py
@@ -1,49 +1,67 @@
-import pytest
 import torch
 import torch.nn as nn
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
 
-@pytest.mark.skip(reason="This test will be skipped.")
-class TestGeLUConverter(DispatchTestCase):
-    def test_gelu(self):
+class TestGELUConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         inputs = [torch.randn(1, 10)]
         self.run_test(TestModule(), inputs)
 
-    def test_gelu_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         input_specs = [
             Input(
-                shape=(-1, -1, -1),
+                min_shape=(1, 1, 1),
+                opt_shape=(1, 2, 3),
+                max_shape=(3, 3, 3),
                 dtype=torch.float32,
-                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
             ),
         ]
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
 
-    def test_gelu_with_dynamic_shape_four_dimensions(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape_four_dimensions(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         input_specs = [
             Input(
-                shape=(-1, -1, -1, -1),
+                min_shape=(1, 1, 1, 5),
+                opt_shape=(1, 2, 3, 5),
+                max_shape=(3, 3, 3, 5),
                 dtype=torch.float32,
-                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
             ),
         ]
-        self.run_test_with_dynamic_shape(TestModule(), input_specs)
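
A rough end-to-end usage sketch, not part of the patch: it assumes a CUDA-enabled torch_tensorrt build and uses the public torch_tensorrt.compile entry point with ir="dynamo" so the new aten.gelu converter (rather than the removed decomposition) handles the op; the module and variable names below are illustrative only.

import torch
import torch_tensorrt


class GELUModule(torch.nn.Module):
    def forward(self, x):
        # approximate="tanh" should exercise the GELU_TANH path;
        # the default approximate="none" maps to GELU_ERF.
        return torch.nn.functional.gelu(x, approximate="tanh")


model = GELUModule().eval().cuda()
inputs = [torch.randn(1, 10, device="cuda")]

# Compile through the dynamo frontend and run the TensorRT module.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))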