Skip to content

Commit

Permalink
chore: dynamic shape support for flip ops (#3046)
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna authored Jul 31, 2024
1 parent 079d90e commit bda5978
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 4 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3571,7 +3571,7 @@ def aten_ops_pdist(
)


@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
@dynamo_tensorrt_converter(torch.ops.aten.flip.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
33 changes: 30 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,24 +446,51 @@ def flip(
output_shape = list(input.shape)
stride_slice = []

dynamic_shape = has_dynamic_shape(input.shape)

shape = input.shape
rank = len(shape)
dims = get_positive_dim(dims, rank)

for i in range(rank):
if i in dims:
start_slice.append(shape[i] - 1)
if shape[i] == DYNAMIC_DIM:
dim = get_shape(
ctx, target, source_ir, f"{name}_shape_dim_{i}", input, i
)
last_element_index = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_sub_{i}", dim, 1
)
start_slice.append(last_element_index)
else:
start_slice.append(shape[i] - 1)
stride_slice.append(-1)
else:
start_slice.append(0)
stride_slice.append(1)

layer = ctx.net.add_slice(
input,
start=start_slice,
shape=output_shape,
start=[] if dynamic_shape else start_slice,
shape=[] if dynamic_shape else output_shape,
stride=stride_slice,
)
if dynamic_shape:
output_shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, f"{name}_shape", output_shape, input
)

start_slice_tensor = cat(
ctx,
target,
source_ir,
f"{name}_start_slice_concat",
start_slice,
0,
)
layer.set_input(1, start_slice_tensor)
layer.set_input(2, output_shape)

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

Expand Down
53 changes: 53 additions & 0 deletions tests/py/dynamo/conversion/test_flip_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -33,5 +34,57 @@ def forward(self, x):
self.run_test(Flip(), inputs)


class TestFlipConverterDynamic(DispatchTestCase):
    """Dynamic-shape conversion tests for ``torch.ops.aten.flip.default``."""

    @parameterized.expand(
        [
            (
                "3d_dynamic",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                [2, 1, 0],
            ),
            (
                "3d_dynamic_negative_dim",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                [-1, 1],
            ),
            (
                "4d_dynamic_static_dim",
                (3, 1, 1, 1),
                (3, 2, 1, 2),
                (3, 2, 4, 5),
                [0, 2, 3],
            ),
            (
                "3d_dynamic_no_dim",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                [],
            ),
        ]
    )
    def test_flip_dynamic(self, _, min_shape, opt_shape, max_shape, dims):
        """Flip a dynamically-shaped input along ``dims`` and check TRT vs eager."""

        class FlipModule(nn.Module):
            # ``dims`` is captured from the enclosing test via closure.
            def forward(self, x):
                return torch.ops.aten.flip.default(x, dims)

        dynamic_spec = Input(
            min_shape=min_shape,
            opt_shape=opt_shape,
            max_shape=max_shape,
            dtype=torch.float,
        )
        self.run_test_with_dynamic_shape(FlipModule(), [dynamic_spec])


# Allow running this test module directly; delegates to PyTorch's test runner.
if __name__ == "__main__":
    run_tests()

0 comments on commit bda5978

Please sign in to comment.