From fdaba9a9eedbe58a1b9a88dde5bb8b9b01203e0e Mon Sep 17 00:00:00 2001 From: HolyWu Date: Fri, 9 Aug 2024 07:11:06 +0800 Subject: [PATCH] feat: Support `aten.dot` dynamo converter (#3043) --- .../dynamo/conversion/aten_ops_converters.py | 1 + .../dynamo/conversion/impl/matmul.py | 14 ++++---- .../dynamo/lowering/_decomposition_groups.py | 1 - .../py/dynamo/conversion/test_matmul_aten.py | 35 +++++++++++++++++++ 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e3aaaec175..98ec69eb44 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -548,6 +548,7 @@ def aten_ops_hard_sigmoid( @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index 5ea29622c8..77dd7ae6f5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -1,15 +1,17 @@ from typing import Optional +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor -from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name -from torch_tensorrt.fx.types import TRTTensor - -import tensorrt as trt +from torch_tensorrt.dynamo.conversion.converter_utils import ( + broadcast, + get_trt_tensor, + set_layer_name, +) +from torch_tensorrt.dynamo.types import TRTTensor def matrix_multiply( @@ -43,7 +45,7 @@ def matrix_multiply( other_matrix_op = trt.MatrixOperation.VECTOR input, other = broadcast( - ctx.net, input, other, f"{name}_input", f"{name}_other", preset_diff + ctx, input, other, f"{name}_input", f"{name}_other", preset_diff ) layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) set_layer_name(layer, target, name, source_ir) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 0a1688b295..ae3e7e1ffa 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -33,7 +33,6 @@ aten.detach, aten.diag_embed, aten.diagonal_backward, - aten.dot, aten.elu_backward, aten.embedding_dense_backward, aten.empty_like, diff --git a/tests/py/dynamo/conversion/test_matmul_aten.py b/tests/py/dynamo/conversion/test_matmul_aten.py index ea39ea06f8..cf1fa36e82 100644 --- a/tests/py/dynamo/conversion/test_matmul_aten.py +++ b/tests/py/dynamo/conversion/test_matmul_aten.py @@ -8,6 +8,41 @@ class TestMatMulConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + "1_1", + (1,), + (1,), + ), + ( + "1_1", + (2,), + (2,), + ), + ( + "1_1", + (3,), + (3,), + ), + ] + ) + def test_matmul_dot(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.other = nn.Parameter(torch.randn(*other_shape)) + + def forward(self, input): + return torch.ops.aten.dot.default(input, self.other) + + inputs = [torch.randn(*input_shape)] + + self.run_test( + MatMul(), + inputs, + ) + @parameterized.expand( [ (