Skip to content

Commit

Permalink
feat: support tile dynamo converter (#2402)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Nov 1, 2023
1 parent 59a4910 commit 5578763
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 2 deletions.
25 changes: 24 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,30 @@ def aten_ops_cumsum(
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_tile(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.tile(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
30 changes: 29 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, Sequence

import numpy as np
import tensorrt as trt
Expand Down Expand Up @@ -203,3 +203,31 @@ def cumsum(
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
loop_output.set_input(1, trip_limit)
return loop_output.get_output(0)


def tile(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dims: Sequence[int],
) -> TRTTensor:
diff = len(dims) - len(input.shape)
if diff > 0:
# prepend 1 to input.shape
new_shape = (1,) * diff + tuple(input.shape)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
)
elif diff < 0:
# prepend 1 to dims
dims = (1,) * -diff + tuple(dims)

shapes = [i * j for i, j in zip(input.shape, dims)]
starts = [0] * len(dims)
strides = [1] * len(dims)
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
layer.mode = trt.SampleMode.WRAP
set_layer_name(layer, target, name)
return layer.get_output(0)
75 changes: 75 additions & 0 deletions tests/py/dynamo/conversion/test_tile_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestTileConverter(DispatchTestCase):
@parameterized.expand(
[
((3,), (1,)),
((3,), (0,)),
((3,), (2,)),
((2,), (2, 2)),
((2,), (0, 2)),
]
)
def test_tile_1D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)

@parameterized.expand(
[
((3, 1), (0,)),
((3, 1), (2,)),
((2, 3), (2, 2)),
((2, 3), (1, 0)),
((2, 3), (0, 2)),
((2, 3), (4, 2, 3)),
((2, 3), (0, 0, 3)),
((2, 3), (4, 2, 3, 1, 2)),
]
)
def test_tile_2D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)

@parameterized.expand(
[
((4, 2, 3), (2,)),
((4, 2, 3), (1, 2)),
((1, 2, 3), (2, 3)),
((1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 3, 4, 5)),
]
)
def test_tile_3D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 5578763

Please sign in to comment.