diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index b4daaaff25..1d734bca03 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -953,7 +953,7 @@ def aten_ops_cumsum(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tile.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index eae0e24dcb..04eab08c47 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -457,6 +457,7 @@ def tile(
     dims: Sequence[int],
 ) -> TRTTensor:
     diff = len(dims) - len(input.shape)
+    has_dynamic_shape_input = has_dynamic_shape(input.shape)
     if diff > 0:
         # prepend 1 to input.shape
         new_shape = (1,) * diff + tuple(input.shape)
@@ -467,10 +468,67 @@ def tile(
         # prepend 1 to dims
         dims = (1,) * -diff + tuple(dims)

-    shapes = [i * j for i, j in zip(input.shape, dims)]
-    starts = [0] * len(dims)
-    strides = [1] * len(dims)
-    layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
+    starts = (0,) * len(dims)
+    strides = (1,) * len(dims)
+    if not has_dynamic_shape_input:
+        shapes = [i * j for i, j in zip(input.shape, dims)]
+        layer = ctx.net.add_slice(input, starts, tuple(shapes), strides)
+    else:
+        # Compute each output extent at runtime: out_dim[i] = in_dim[i] * dims[i],
+        # fetching any dynamic input dim from the shape tensor.
+        shapes = []
+        index = 0
+        for i, j in zip(input.shape, dims):
+            if i == DYNAMIC_DIM:
+                i = get_shape(
+                    ctx, target, source_ir, name + f"_input_{index}", input, index
+                )
+            prod_shape = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_prod_{index}",
+                trt.ElementWiseOperation.PROD,
+                i,
+                j,
+            )
+            shapes.append(prod_shape)
+            index += 1
+        # Placeholder dims only; the actual start/shape/stride are supplied
+        # below as tensor inputs via set_input.
+        layer = ctx.net.add_slice(
+            input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+        )
+        shape_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_shape_concat",
+            tuple(shapes),
+            0,
+            cast_dtype=trt.int32,
+        )
+        start_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_start_concat",
+            starts,
+            0,
+            cast_dtype=trt.int32,
+        )
+        stride_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_stride_concat",
+            strides,
+            0,
+            cast_dtype=trt.int32,
+        )
+        layer.set_input(1, start_tensor)
+        layer.set_input(2, shape_tensor)
+        layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.WRAP
     set_layer_name(layer, target, name)
     return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_tile_aten.py b/tests/py/dynamo/conversion/test_tile_aten.py
index 5a7e98aa7d..47ac9c85c7 100644
--- a/tests/py/dynamo/conversion/test_tile_aten.py
+++ b/tests/py/dynamo/conversion/test_tile_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -71,5 +72,51 @@ def forward(self, x):
         )


+class TestTileConverterDynamicShape(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # (min_shape, opt_shape, max_shape, dims)
+            # 1d cases
+            ((3,), (3,), (6,), (1,)),
+            ((3,), (3,), (6,), (0,)),
+            ((3,), (3,), (6,), (2,)),
+            ((2,), (3,), (6,), (2, 2)),
+            ((2,), (3,), (6,), (0, 2)),
+            # 2d cases
+            ((3, 1), (3, 1), (6, 1), (0,)),
+            ((3, 1), (3, 1), (6, 1), (2,)),
+            ((2, 3), (2, 3), (4, 3), (2, 2)),
+            ((2, 3), (2, 3), (4, 3), (1, 0)),
+            ((2, 3), (2, 3), (4, 3), (0, 2)),
+            ((2, 3), (2, 3), (4, 3), (4, 2, 3)),
+            ((2, 3), (2, 3), (4, 3), (0, 0, 3)),
+            ((2, 3), (2, 3), (4, 3), (4, 2, 3, 1, 2)),
+            # 3d cases
+            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (2,)),
+            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (1, 2)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4, 5)),
+        ]
+    )
+    def test_tile_input_dynamic(self, min_shape, opt_shape, max_shape, dims):
+        class Tile(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.tile.default(x, dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float32,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Tile(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
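
Note for reviewers: the converter still lowers tile to a single ISliceLayer in trt.SampleMode.WRAP, where reads past the end of a dimension wrap back to index 0; the only change is that the slice's output extents (in_dim * rep per axis) are now assembled as a runtime shape tensor when the input has dynamic dims. A small NumPy sketch of the wrap-slice idea (illustrative only, not the TensorRT API):

import numpy as np

def wrap_slice_tile(a: np.ndarray, dims: tuple) -> np.ndarray:
    # Pad dims / shape to the same rank, mirroring the converter's prologue.
    diff = len(dims) - a.ndim
    if diff > 0:
        a = a.reshape((1,) * diff + a.shape)
    else:
        dims = (1,) * -diff + tuple(dims)
    # Output extent per axis is in_dim * rep, as in the converter.
    out_shape = tuple(s * d for s, d in zip(a.shape, dims))
    # Emulate SampleMode.WRAP: output index i reads input element i % in_dim.
    idx = np.ix_(*[np.arange(n) % s for n, s in zip(out_shape, a.shape)])
    return a[idx]

a = np.arange(6).reshape(2, 3)
assert (wrap_slice_tile(a, (2, 2)) == np.tile(a, (2, 2))).all()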
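A minimal end-to-end sketch of exercising the new dynamic path through the dynamo frontend (not part of this diff; the module and shapes are illustrative and assume this patch is installed):

import torch
import torch_tensorrt

class TileModule(torch.nn.Module):
    def forward(self, x):
        # dims shorter than the input rank are left-padded with 1s,
        # so (2, 3) tiles only the last two dimensions
        return torch.ops.aten.tile.default(x, (2, 3))

model = TileModule().eval().cuda()
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 2, 3),
        opt_shape=(4, 2, 3),
        max_shape=(8, 2, 3),  # first dim is dynamic
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
x = torch.randn(4, 2, 3).cuda()
print(trt_model(x).shape)  # torch.Size([4, 4, 9])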