feat: Support aten.dot dynamo converter (#3043)
HolyWu authored Aug 8, 2024
1 parent 8ecc809 · commit fdaba9a
Showing 4 changed files with 44 additions and 7 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -548,6 +548,7 @@ def aten_ops_hard_sigmoid(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True)
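The decorators above all stack on a single converter function that the collapsed diff hides. A minimal sketch of that pattern, assuming the signature shared by the other converters in aten_ops_converters.py (the body and import paths below are taken from elsewhere in this commit and from the file's conventions, not from the hidden lines):

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.types import TRTTensor


def aten_ops_matmul(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Every op registered above (matmul, dot, mm, mv, bmm) funnels into the
    # same TensorRT matrix-multiply implementation in impl/matmul.py.
    return impl.matmul.matrix_multiply(
        ctx, target, SourceIR.ATEN, name, args[0], args[1]
    )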
14 changes: 8 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -1,15 +1,17 @@
 from typing import Optional
 
-import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+
+import tensorrt as trt
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcast,
+    get_trt_tensor,
+    set_layer_name,
+)
+from torch_tensorrt.dynamo.types import TRTTensor


def matrix_multiply(
@@ -43,7 +45,7 @@ def matrix_multiply(
         other_matrix_op = trt.MatrixOperation.VECTOR
 
     input, other = broadcast(
-        ctx.net, input, other, f"{name}_input", f"{name}_other", preset_diff
+        ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
     layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
     set_layer_name(layer, target, name, source_ir)
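Why routing aten.dot through matrix_multiply works: the converter tags 1-D operands with trt.MatrixOperation.VECTOR, as the visible context line shows, and TensorRT's matrix multiply of two VECTOR operands is precisely a dot product. A quick eager-mode sanity check of that equivalence (plain PyTorch, no TensorRT required):

import torch

a = torch.randn(3)
b = torch.randn(3)

# dot contracts two 1-D tensors to a scalar; matmul on the same inputs
# performs the identical contraction, which is what lets dot share the
# matrix-multiply converter above.
assert torch.allclose(torch.dot(a, b), torch.matmul(a, b))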
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -33,7 +33,6 @@
     aten.detach,
     aten.diag_embed,
     aten.diagonal_backward,
-    aten.dot,
     aten.elu_backward,
     aten.embedding_dense_backward,
     aten.empty_like,
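Deleting aten.dot from this decomposition group is what activates the new converter: ops in the group are decomposed into other ATen ops before conversion, so as long as dot was listed here the converter registered above never saw it. A hedged end-to-end sketch of exercising the new path (the compile settings below are assumptions for illustration, not part of this commit):

import torch
import torch_tensorrt


class Dot(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.dot(a, b)


model = Dot().cuda().eval()
inputs = (torch.randn(3).cuda(), torch.randn(3).cuda())

# With aten.dot no longer decomposed, the dynamo front end can hand the
# op straight to the TensorRT converter; min_block_size=1 keeps this tiny
# graph from falling back to eager PyTorch.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs))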
35 changes: 35 additions & 0 deletions tests/py/dynamo/conversion/test_matmul_aten.py
@@ -8,6 +8,41 @@
 
 
 class TestMatMulConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "1_1",
+                (1,),
+                (1,),
+            ),
+            (
+                "1_1",
+                (2,),
+                (2,),
+            ),
+            (
+                "1_1",
+                (3,),
+                (3,),
+            ),
+        ]
+    )
+    def test_matmul_dot(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.ops.aten.dot.default(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+        )
+
     @parameterized.expand(
         [
             (
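The new cases follow the file's existing DispatchTestCase pattern: run_test traces the module, converts it through the registered converter, and compares the TensorRT output against eager PyTorch; making other an nn.Parameter also exercises the weight (get_trt_tensor) path. To run only these cases, something like the following should work (an assumed invocation, mirroring how the other dynamo conversion tests are run):

import pytest

# Select just the new dot tests; the -k filter matches test_matmul_dot_0/_1/_2.
pytest.main(["tests/py/dynamo/conversion/test_matmul_aten.py", "-k", "test_matmul_dot"])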
