Skip to content

Commit

Permalink
tile dynamic dim (#3085)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored Aug 29, 2024
1 parent 39f8255 commit ffa4f64
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def aten_ops_cumsum(
)


@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
@dynamo_tensorrt_converter(torch.ops.aten.tile.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
63 changes: 59 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def tile(
dims: Sequence[int],
) -> TRTTensor:
diff = len(dims) - len(input.shape)
has_dynamic_shape_input = has_dynamic_shape(input.shape)
if diff > 0:
# prepend 1 to input.shape
new_shape = (1,) * diff + tuple(input.shape)
Expand All @@ -467,10 +468,64 @@ def tile(
# prepend 1 to dims
dims = (1,) * -diff + tuple(dims)

shapes = [i * j for i, j in zip(input.shape, dims)]
starts = [0] * len(dims)
strides = [1] * len(dims)
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
starts = tuple([0] * len(dims))
strides = tuple([1] * len(dims))
# layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
if not (has_dynamic_shape_input):
shapes = [i * j for i, j in zip(input.shape, dims)]
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
else:
shapes = []
index = 0
for i, j in zip(input.shape, dims):
if i == DYNAMIC_DIM:
i = get_shape(
ctx, target, source_ir, name + f"_input_{index}", input, index
)
prod_shape = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_prod",
trt.ElementWiseOperation.PROD,
i,
j,
)
shapes.append(prod_shape)
index = index + 1
layer = ctx.net.add_slice(
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
shape_tensor = cat(
ctx,
target,
source_ir,
name + "_shape_concat",
tuple(shapes),
0,
cast_dtype=trt.int32,
)
start_tensor = cat(
ctx,
target,
source_ir,
name + "_start_concat",
starts,
0,
cast_dtype=trt.int32,
)
stride_tensor = cat(
ctx,
target,
source_ir,
name + "_stride_concat",
strides,
0,
cast_dtype=trt.int32,
)
layer.set_input(1, start_tensor)
layer.set_input(2, shape_tensor)
layer.set_input(3, stride_tensor)
layer.mode = trt.SampleMode.WRAP
set_layer_name(layer, target, name)
return layer.get_output(0)
Expand Down
45 changes: 45 additions & 0 deletions tests/py/dynamo/conversion/test_tile_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -71,5 +72,49 @@ def forward(self, x):
)


class TestTileConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
((3,), (3,), (6,), (1,)),
((3,), (3,), (6,), (0,)),
((3,), (3,), (6,), (2,)),
((2,), (3,), (6,), (2, 2)),
((2,), (3,), (6,), (0, 2)),
# 2d cases
((3, 1), (3, 1), (6, 1), (0,)),
((3, 1), (3, 1), (6, 1), (2,)),
((2, 3), (2, 3), (4, 3), (2, 2)),
((2, 3), (2, 3), (4, 3), (1, 0)),
((2, 3), (2, 3), (4, 3), (0, 2)),
((2, 3), (2, 3), (4, 3), (4, 2, 3)),
((2, 3), (2, 3), (4, 3), (0, 0, 3)),
((2, 3), (2, 3), (4, 3), (4, 2, 3, 1, 2)),
# 3d cases
((4, 2, 3), (4, 2, 3), (6, 2, 3), (2,)),
((4, 2, 3), (4, 2, 3), (6, 2, 3), (1, 2)),
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3)),
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4)),
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4, 5)),
]
)
def test_tile_input_dynamic(self, min_shape, opt_shape, max_shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(
Tile(),
input_specs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit ffa4f64

Please sign in to comment.