chore: dynamic shape support for any/sort/trunc ops (#3026)
keehyuna authored Jul 31, 2024
1 parent 784fa57 commit 23b4f1e
Showing 5 changed files with 265 additions and 9 deletions.
19 changes: 13 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2668,10 +2668,15 @@ def topk_validator(node: Node) -> bool:
 
 
 def sort_validator(node: Node) -> bool:
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
     k = shape[dim]
+    if not isinstance(k, int):
+        return False
     return topk_sort_validator(k)
 
 
@@ -3436,7 +3441,9 @@ def aten_ops_topk(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3461,7 +3468,7 @@ def aten_ops_sort(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3537,9 +3544,9 @@ def aten_ops_remainder(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
 def aten_ops_any(
     ctx: ConversionContext,
     target: Target,
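For context, a brief illustrative sketch (hypothetical code, not part of this commit) of why the added validator guards matter: under dynamic-shape tracing, a dimension of tensor_meta.shape is a torch.SymInt rather than a plain Python int, so the new isinstance(k, int) check makes sort_validator reject nodes whose sort dimension is dynamic instead of passing an unknown k to topk_sort_validator.

# Hypothetical illustration (not from this diff) of the validator's behavior
# for static vs. dynamic sort dimensions.
import torch

def can_convert_sort(shape, dim: int) -> bool:
    # Mirrors the added guard: only a concrete int k can be validated.
    k = shape[dim]
    return isinstance(k, int)

print(can_convert_sort(torch.Size([2, 3, 4]), -1))  # True -> TensorRT converter is used
# Under dynamic-shape tracing, shape[dim] would be a torch.SymInt, which is not
# an int instance, so the validator returns False and the op falls back to PyTorch.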
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -10,9 +10,10 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
 
 def argmax_argmin(
@@ -155,9 +156,14 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
+
+    # The topk layer supports a dynamic k value, but the supported range of a
+    # dynamic k cannot be determined at compile time.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
+
     # TensorRT's ITopKLayer has no sorted flag; it always returns the top-k elements
     # sorted, so the returned Tensor is sorted whether `sorted` is True or False.
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)
 
     if return_indices:
         return topk_layer.get_output(0), topk_layer.get_output(1)
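As a side note, a minimal sketch of the guard the new assert encodes (illustrative only; it assumes DYNAMIC_DIM is the -1 sentinel that torch_tensorrt.dynamo.utils uses for unknown dimensions):

# Illustrative only (not from this diff). ITopKLayer can consume k as a runtime
# value, but the converter cannot verify an unknown k against TensorRT's
# supported limits at build time, so a dynamic k is rejected up front.
DYNAMIC_DIM = -1  # assumed sentinel for a dynamic dimension

def require_static_k(k: int) -> int:
    if k == DYNAMIC_DIM:
        raise AssertionError("k value cannot be dynamic!")
    return k

require_static_k(5)  # ok: a concrete k can be checked and converted
# require_static_k(DYNAMIC_DIM) would raise during conversion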
148 changes: 148 additions & 0 deletions tests/py/dynamo/conversion/test_any.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -184,5 +185,152 @@ def forward(self, x):
         )
 
 
+class TestAnyConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+            ),
+            (
+                "2d_dynamic_int32",
+                (2, 2),
+                (2, 2),
+                (3, 2),
+                torch.int32,
+            ),
+            (
+                "4d_dynamic_bool",
+                (1, 2, 1, 1),
+                (2, 2, 2, 2),
+                (2, 2, 4, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
+        class Any(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.default(x)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Any(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dim_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_dim_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                -2,
+                False,
+            ),
+            (
+                "3d_dynamic_dim_bool",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.bool,
+                0,
+                True,
+            ),
+        ]
+    )
+    def test_any_dynamic_dim(
+        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
+    ):
+        class AnyDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dim(x, dim, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDim(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dims_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                [1, 2],
+                True,
+            ),
+            (
+                "4d_dynamic_dims_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                [2, -1],
+                False,
+            ),
+            (
+                "3d_dynamic_dims_bool",
+                (1, 4, 1),
+                (2, 4, 2),
+                (4, 4, 3),
+                torch.bool,
+                [0, 1, 2],
+                False,
+            ),
+        ]
+    )
+    def test_any_dynamic_dims(
+        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
+    ):
+        class AnyDims(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dims(x, dims, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDims(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
51 changes: 51 additions & 0 deletions tests/py/dynamo/conversion/test_sort_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -32,5 +33,55 @@ def forward(self, x):
         )
 
 
+class TestSortConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_descending",
+                (2, 1, 4),
+                (3, 2, 4),
+                (3, 3, 4),
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_ascending",
+                (2, 2, 1, 4),
+                (2, 2, 2, 4),
+                (3, 3, 2, 4),
+                3,
+                False,
+            ),
+            (
+                "4d_dynamic_descending_neg_dim",
+                (1, 3, 1, 1),
+                (2, 3, 2, 2),
+                (3, 3, 2, 4),
+                -3,
+                True,
+            ),
+        ]
+    )
+    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sort(),
+            input_specs,
+            output_dtypes=[torch.float, torch.int64],
+            use_dynamo_tracer=True,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
44 changes: 44 additions & 0 deletions tests/py/dynamo/conversion/test_trunc_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -48,5 +49,48 @@ def forward(self, input):
         )
 
 
+class TestTruncConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_int32",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 4, 5),
+                torch.int32,
+                False,
+            ),
+            (
+                "3d_dynamic_float32",
+                (2, 1, 1),
+                (2, 2, 2),
+                (2, 4, 5),
+                torch.float32,
+                True,
+            ),
+        ]
+    )
+    def test_trunc_dynamic(
+        self, _, min_shape, opt_shape, max_shape, type, enable_passes
+    ):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Trunc(),
+            input_specs,
+            enable_passes=enable_passes,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

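As a usage note, a hedged end-to-end sketch (not part of this commit; the module, shapes, and settings below are illustrative assumptions) of how these converters are exercised with a dynamic batch dimension through the public dynamo frontend:

# Illustrative sketch: compiling a module that uses any/sort/trunc with a
# dynamic batch dimension (requires a CUDA device and TensorRT).
import torch
import torch_tensorrt

class Model(torch.nn.Module):
    def forward(self, x):
        values, _indices = torch.sort(x, dim=-1, descending=True)
        return torch.trunc(values), torch.any(x, dim=1)

inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 8),
        opt_shape=(4, 3, 8),
        max_shape=(8, 3, 8),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(Model().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_model(torch.randn(4, 3, 8).cuda()))  # batch size may vary between 1 and 8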