From bda59787804d3e31b4b96c012f80a6c992eb1edc Mon Sep 17 00:00:00 2001
From: keehyun
Date: Thu, 1 Aug 2024 02:53:34 +0900
Subject: [PATCH] chore: dynamic shape support for flip ops (#3046)

---
 .../dynamo/conversion/aten_ops_converters.py  |  2 +-
 .../dynamo/conversion/impl/slice/ops.py       | 33 ++++++++++--
 tests/py/dynamo/conversion/test_flip_aten.py  | 53 +++++++++++++++++++
 3 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 2391a15ad1..b40b01ae88 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3571,7 +3571,7 @@ def aten_ops_pdist(
     )
-@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.flip.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 776e2bec8e..cbd55d9d55 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -446,13 +446,24 @@ def flip(
     output_shape = list(input.shape)
     stride_slice = []
 
+    dynamic_shape = has_dynamic_shape(input.shape)
+
     shape = input.shape
     rank = len(shape)
     dims = get_positive_dim(dims, rank)
 
     for i in range(rank):
         if i in dims:
-            start_slice.append(shape[i] - 1)
+            if shape[i] == DYNAMIC_DIM:
+                dim = get_shape(
+                    ctx, target, source_ir, f"{name}_shape_dim_{i}", input, i
+                )
+                last_element_index = impl.elementwise.sub(
+                    ctx, target, source_ir, f"{name}_sub_{i}", dim, 1
+                )
+                start_slice.append(last_element_index)
+            else:
+                start_slice.append(shape[i] - 1)
             stride_slice.append(-1)
         else:
             start_slice.append(0)
@@ -460,10 +471,26 @@ def flip(
 
     layer = ctx.net.add_slice(
         input,
-        start=start_slice,
-        shape=output_shape,
+        start=[] if dynamic_shape else start_slice,
+        shape=[] if dynamic_shape else output_shape,
         stride=stride_slice,
     )
+    if dynamic_shape:
+        output_shape = get_shape_with_dynamic_shape(
+            ctx, target, source_ir, f"{name}_shape", output_shape, input
+        )
+
+        start_slice_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_start_slice_concat",
+            start_slice,
+            0,
+        )
+        layer.set_input(1, start_slice_tensor)
+        layer.set_input(2, output_shape)
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
 
diff --git a/tests/py/dynamo/conversion/test_flip_aten.py b/tests/py/dynamo/conversion/test_flip_aten.py
index aa4a2cd374..489aed6030 100644
--- a/tests/py/dynamo/conversion/test_flip_aten.py
+++ b/tests/py/dynamo/conversion/test_flip_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -33,5 +34,57 @@ def forward(self, x):
         self.run_test(Flip(), inputs)
 
 
+class TestFlipConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [2, 1, 0],
+            ),
+            (
+                "3d_dynamic_negative_dim",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [-1, 1],
+            ),
+            (
+                "4d_dynamic_static_dim",
+                (3, 1, 1, 1),
+                (3, 2, 1, 2),
+                (3, 2, 4, 5),
+                [0, 2, 3],
+            ),
+            (
+                "3d_dynamic_no_dim",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [],
+            ),
+        ]
+    )
+    def test_flip_dynamic(self, _, min_shape, opt_shape, max_shape, dims):
+        class Flip(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.flip.default(x, dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Flip(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()