feat: dynamic shape support for aten.select.int (#2990)
chohk88 authored Jul 31, 2024
1 parent 5bd948f commit c99c966
Showing 3 changed files with 66 additions and 48 deletions.
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (2 changes: 1 addition & 1 deletion)

@@ -798,7 +798,7 @@ def aten_ops_scatter(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.select.int)
+@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True)
 def aten_ops_select(
     ctx: ConversionContext,
     target: Target,
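Note on the flag: without supports_dynamic_shapes=True, the dynamo converter registry treats a converter as static-shape-only, so graphs with symbolic input dimensions fall back instead of using it. A minimal registration sketch follows; the decorator name and keyword are taken from the diff above, while the import paths and the dispatch body (truncated in this view) are assumptions for illustration.

# Sketch of a dynamic-shape-capable converter registration. The decorator and
# the supports_dynamic_shapes keyword come from the diff above; the import
# paths and the function body are assumed for illustration.
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)


@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True)
def aten_ops_select(
    ctx: ConversionContext,
    target: Target,
    args,
    kwargs,
    name: str,
):
    # aten.select.int(Tensor self, int dim, int index) -> Tensor
    return impl.select.select(
        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
    )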
py/torch_tensorrt/dynamo/conversion/impl/select.py (36 changes: 9 additions & 27 deletions)

@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Sequence, Union, cast
+from typing import Optional, Sequence, Union

 import numpy as np
 import tensorrt as trt
@@ -21,7 +21,7 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -32,8 +32,8 @@ def select(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Shape,
-    index: Shape,
+    dim: int,
+    index: int,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
@@ -42,31 +42,13 @@
         )

     ranks = len(input.shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
-    dynamic_shape = has_dynamic_shape(input.shape)
-    if dynamic_shape:
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
-    index = index
+    dim = get_positive_dim(dim, ranks)

-    if index >= input.shape[dim]:
-        raise RuntimeError(
-            f"cannot have index greater than the dimension length! {input.shape[dim]}"
-        )
-    output_shape = list(input.shape)
-    output_shape[dim] = 1
-    if dynamic_shape > 0:
-        output_shape = get_shape_with_dynamic_shape(
-            ctx, target, source_ir, name, output_shape, input
-        )
-    index_value = np.array(index, dtype=np.int32)
-    indices_tensor = ctx.net.add_constant(
-        index_value.shape, to_numpy(index_value)
-    ).get_output(0)
+    indices_tensor = get_trt_tensor(
+        ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
+    )
     layer = ctx.net.add_gather(input, indices_tensor, dim)
-    out = layer.get_output(0)
-    if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)

     return layer.get_output(0)
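Taken together, these hunks replace the static-shape bookkeeping (the index bounds check, get_shape_with_dynamic_shape, and the trailing shuffle) with a single gather over a constant index tensor. A reconstruction of the resulting function is sketched below for reading convenience, with imports as in the file above; the RuntimeError message is elided in this view, so its wording here is assumed.

def select(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: int,
) -> TRTTensor:
    if not isinstance(input, TRTTensor):
        # Exact wording elided in the diff view; assumed here.
        raise RuntimeError(f"select received non-TRTTensor input {input}")

    ranks = len(input.shape)
    dim = get_positive_dim(dim, ranks)

    # Materialize the index as a 0-D int32 constant. Gathering with a 0-D
    # indices tensor drops the gathered axis (output rank = input rank +
    # indices rank - 1), which matches aten.select semantics, and TensorRT
    # resolves any dynamic dimensions at runtime.
    indices_tensor = get_trt_tensor(
        ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
    )
    layer = ctx.net.add_gather(input, indices_tensor, dim)

    return layer.get_output(0)

Deferring shape resolution to the gather layer is what enables the dynamic-shape test cases below: no dimension of input.shape needs to be known at conversion time.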
tests/py/dynamo/conversion/test_select_aten.py (76 changes: 56 additions & 20 deletions)

@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
@@ -9,11 +10,11 @@
 class TestSelectConverterOne(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_select_2d(self, _, dim, index):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -22,19 +23,17 @@ def forward(self, input):

         input = [torch.randn(1, 2)]
         self.run_test(
-            TestModule(),
+            select(),
             input,
         )


 class TestSelectConverterTwo(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_select_4d(self, _, dim, index):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -43,33 +42,70 @@ def forward(self, input):

         input = [torch.randn(4, 4, 4, 4)]
         self.run_test(
-            TestModule(),
+            select(),
             input,
         )


 class TestSelectConverterWithDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            (
+                "partial_dynamic_static_dim",
+                (1, 1, 3),
+                (2, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                2,
+                0,
+            ),
+            (
+                "partial_dynamic_dynamic_dim",
+                (1, 1, 3),
+                (2, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                1,
+                1,
+            ),
+            (
+                "fully_dynamic",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                1,
+                1,
+            ),
+            (
+                "fully_dynamic_neg_dim",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                -1,
+                1,
+            ),
         ]
     )
-    def test_select_with_dynamic_shape(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_dynamic_shape_select(
+        self, _, min_shape, opt_shape, max_shape, type, dim, index
+    ):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
                 return torch.ops.aten.select.int(input, dim, index)

-        input_spec = [
+        input_specs = [
             Input(
-                shape=(-1, 3, 3),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
             ),
         ]
-        self.run_test_with_dynamic_shape(TestModule(), input_spec)
+
+        self.run_test_with_dynamic_shape(select(), input_specs)


 if __name__ == "__main__":
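The new parameterizations exercise a static selected dim alongside dynamic dims, a dynamic selected dim, fully dynamic inputs, and a negative dim. For orientation, a hedged sketch of driving the same path from user code; torch_tensorrt.compile and the Input min/opt/max keywords are the public API shown above, while the model, shapes, and ir choice are illustrative.

import torch
import torch_tensorrt


class SelectModel(torch.nn.Module):
    def forward(self, x):
        # Same op the tests target: select index 1 along the last (negative) dim.
        return torch.ops.aten.select.int(x, -1, 1)


# The min/opt/max triples mirror the "fully_dynamic_neg_dim" case above.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 1),
        opt_shape=(2, 2, 2),
        max_shape=(3, 3, 3),
        dtype=torch.float,
    )
]
trt_mod = torch_tensorrt.compile(SelectModel().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(2, 2, 2).cuda()))  # any shape within the declared range works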
