feat: Support aten.gelu dynamo converter (#3134)
HolyWu authored Sep 3, 2024
1 parent 0f8f23d commit d75f588
Showing 4 changed files with 75 additions and 17 deletions.
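Taken together, the changes register a native TensorRT converter for torch.ops.aten.gelu.default (with dynamic-shape support), implement it via TensorRT's GELU activation types, stop decomposing aten.gelu during dynamo lowering, and parameterize the tests over both approximate modes. A minimal usage sketch, not part of the diff, assuming a Torch-TensorRT build with the dynamo frontend and a CUDA device (the exact API surface may differ across versions):

# Illustrative only: with this commit, aten.gelu hits the new TensorRT converter
# instead of being decomposed during dynamo lowering.
import torch
import torch_tensorrt


class GELUModel(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.gelu.default with approximate="tanh"
        return torch.nn.functional.gelu(x, approximate="tanh")


model = GELUModel().eval().cuda()
inputs = [torch.randn(1, 10).cuda()]

# ir="dynamo" selects the lowering path that uses aten_ops_converters.py
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)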
20 changes: 19 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -19,7 +19,7 @@
    get_positive_dim,
    is_only_operator_on_placeholder,
)
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -548,6 +548,24 @@ def aten_ops_hard_sigmoid(
    )


+@dynamo_tensorrt_converter(torch.ops.aten.gelu.default, supports_dynamic_shapes=True)
+def aten_ops_gelu(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.gelu(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs.get("approximate", "none"),
+    )
+
+
@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
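The converter forwards the FX node's first positional argument (the input tensor) and the optional approximate keyword to the impl; kwargs.get("approximate", "none") mirrors the op's schema default. A quick check of the ATen op being mapped (illustrative, not from the diff):

# aten.gelu.default takes the input positionally and an optional "approximate"
# keyword that defaults to "none".
import torch

x = torch.randn(2, 3)
print(torch.ops.aten.gelu.default(x))                      # erf-based ("exact") GELU
print(torch.ops.aten.gelu.default(x, approximate="tanh"))  # tanh approximation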
25 changes: 24 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -7,7 +7,7 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor


def relu(
@@ -327,3 +327,26 @@ def thresholded_relu_fn(x: float) -> float:
        alpha=alpha,
        dyn_range_fn=thresholded_relu_dyn_range_fn,
    )
+
+
+def gelu(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    approximate: str,
+) -> TRTTensor:
+    if approximate == "none":
+        operation_type = trt.ActivationType.GELU_ERF
+    elif approximate == "tanh":
+        operation_type = trt.ActivationType.GELU_TANH
+
+    return convert_activation(
+        ctx,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+    )
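The if/elif covers every valid input, since PyTorch's gelu only accepts approximate="none" or "tanh"; the two branches map to TensorRT's GELU_ERF and GELU_TANH activation types (assuming a TensorRT build that exposes those enum members). A reference sketch of what the two modes compute, not part of the commit:

import math

import torch


def gelu_erf(x: torch.Tensor) -> torch.Tensor:
    # "none" / GELU_ERF: exact GELU, x * Phi(x) with the Gaussian CDF via erf
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # "tanh" / GELU_TANH: the common tanh approximation
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


x = torch.randn(4)
print(torch.allclose(gelu_erf(x), torch.nn.functional.gelu(x), atol=1e-6))
print(torch.allclose(gelu_tanh(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6))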
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -42,7 +42,6 @@
    aten.fill,
    aten.frac,
    aten._fused_moving_avg_obs_fq_helper,
-    aten.gelu,
    aten.gelu_backward,
    aten.glu_backward,
    aten.hardshrink,
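Removing aten.gelu from the decomposition group is what lets the op reach the new converter: while it was listed, lowering rewrote gelu into primitive ops before conversion, so the converter never saw it. A rough sketch of that effect using core PyTorch export APIs (illustrative only; this is not Torch-TensorRT's own lowering pass):

import torch
from torch._decomp import get_decompositions


class GELUModule(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")


ep = torch.export.export(GELUModule(), (torch.randn(2, 8),))

# Before any decompositions, the exported graph still contains aten.gelu.default.
print(any(n.target == torch.ops.aten.gelu.default for n in ep.graph.nodes))  # True

# If aten.gelu is in the decomposition table (as it was before this commit),
# it is rewritten into primitive ops and a gelu converter would never match.
decomposed = ep.run_decompositions(get_decompositions([torch.ops.aten.gelu]))
print(any(n.target == torch.ops.aten.gelu.default for n in decomposed.graph.nodes))  # False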
46 changes: 32 additions & 14 deletions tests/py/dynamo/conversion/test_gelu_aten.py
@@ -1,49 +1,67 @@
-import pytest
import torch
import torch.nn as nn
+from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


-@pytest.mark.skip(reason="This test will be skipped.")
-class TestGeLUConverter(DispatchTestCase):
-    def test_gelu(self):
+class TestGELUConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu(self, approximate):
        class TestModule(nn.Module):
            def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)

        inputs = [torch.randn(1, 10)]
        self.run_test(TestModule(), inputs)

-    def test_gelu_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape(self, approximate):
        class TestModule(nn.Module):
            def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)

        input_specs = [
            Input(
-                shape=(-1, -1, -1),
+                min_shape=(1, 1, 1),
+                opt_shape=(1, 2, 3),
+                max_shape=(3, 3, 3),
                dtype=torch.float32,
-                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(TestModule(), input_specs)

-    def test_gelu_with_dynamic_shape_four_dimensions(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape_four_dimensions(self, approximate):
        class TestModule(nn.Module):
            def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)

        input_specs = [
            Input(
-                shape=(-1, -1, -1, -1),
+                min_shape=(1, 1, 1, 5),
+                opt_shape=(1, 2, 3, 5),
+                max_shape=(3, 3, 3, 5),
                dtype=torch.float32,
-                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
            ),
        ]

        self.run_test_with_dynamic_shape(TestModule(), input_specs)


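For reference, parameterized.expand generates one test method per tuple, so each test above now runs once with approximate="none" (GELU_ERF) and once with approximate="tanh" (GELU_TANH). A minimal, self-contained sketch of the pattern (the class and assertions here are illustrative, not from the commit):

import unittest

from parameterized import parameterized


class ParameterizedSketch(unittest.TestCase):
    @parameterized.expand(
        [
            ("none",),
            ("tanh",),
        ]
    )
    def test_approximate_values(self, approximate):
        # Each tuple becomes a separate generated test, with the index and
        # parameter value appended to the test name.
        self.assertIn(approximate, ("none", "tanh"))


if __name__ == "__main__":
    unittest.main()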
