From 23b4f1e9ec9014942466f5f29b43e6c92a336887 Mon Sep 17 00:00:00 2001
From: keehyun
Date: Thu, 1 Aug 2024 06:22:33 +0900
Subject: [PATCH] chore: dynamic shape support for any/sort/trunc ops (#3026)

---
 .../dynamo/conversion/aten_ops_converters.py  |  19 ++-
 .../dynamo/conversion/impl/topk.py            |  12 +-
 tests/py/dynamo/conversion/test_any.py        | 148 ++++++++++++++++++
 tests/py/dynamo/conversion/test_sort_aten.py  |  51 ++++++
 tests/py/dynamo/conversion/test_trunc_aten.py |  44 ++++++
 5 files changed, 265 insertions(+), 9 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index a522fc2049..4c5d6706ab 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2668,10 +2668,15 @@ def topk_validator(node: Node) -> bool:
 
 
 def sort_validator(node: Node) -> bool:
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
     k = shape[dim]
+    if not isinstance(k, int):
+        return False
     return topk_sort_validator(k)
 
 
@@ -3436,7 +3441,9 @@ def aten_ops_topk(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3461,7 +3468,7 @@ def aten_ops_sort(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3537,9 +3544,9 @@ def aten_ops_remainder(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
 def aten_ops_any(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
index 78dd25d5a1..007f248af1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -10,9 +10,10 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
 
 def argmax_argmin(
@@ -155,9 +156,14 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
+
+    # topk layer supports dynamic k value but we cannot determine the supported
+    # dynamic topk value at compile time.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
+
     # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
     # so here no matter sorted is True or False the returned the topk Tensor object is always sorted
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)
 
     if return_indices:
         return topk_layer.get_output(0), topk_layer.get_output(1)
diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py
index 29522145da..1d1fc634ef 100644
--- a/tests/py/dynamo/conversion/test_any.py
+++ b/tests/py/dynamo/conversion/test_any.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -184,5 +185,152 @@ def forward(self, x):
         )
 
 
+class TestAnyConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+            ),
+            (
+                "2d_dynamic_int32",
+                (2, 2),
+                (2, 2),
+                (3, 2),
+                torch.int32,
+            ),
+            (
+                "4d_dynamic_bool",
+                (1, 2, 1, 1),
+                (2, 2, 2, 2),
+                (2, 2, 4, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
+        class Any(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.default(x)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Any(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dim_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_dim_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                -2,
+                False,
+            ),
+            (
+                "3d_dynamic_dim_bool",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.bool,
+                0,
+                True,
+            ),
+        ]
+    )
+    def test_any_dynamic_dim(
+        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
+    ):
+        class AnyDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dim(x, dim, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDim(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dims_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                [1, 2],
+                True,
+            ),
+            (
+                "4d_dynamic_dims_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                [2, -1],
+                False,
+            ),
+            (
+                "3d_dynamic_dims_bool",
+                (1, 4, 1),
+                (2, 4, 2),
+                (4, 4, 3),
+                torch.bool,
+                [0, 1, 2],
+                False,
+            ),
+        ]
+    )
+    def test_any_dynamic_dims(
+        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
+    ):
+        class AnyDims(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dims(x, dims, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDims(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py
index 8382da0047..5f1258c6ac 100644
--- a/tests/py/dynamo/conversion/test_sort_aten.py
+++ b/tests/py/dynamo/conversion/test_sort_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -32,5 +33,55 @@ def forward(self, x):
         )
 
 
+class TestSortConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_descending",
+                (2, 1, 4),
+                (3, 2, 4),
+                (3, 3, 4),
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_ascending",
+                (2, 2, 1, 4),
+                (2, 2, 2, 4),
+                (3, 3, 2, 4),
+                3,
+                False,
+            ),
+            (
+                "4d_dynamic_descending_neg_dim",
+                (1, 3, 1, 1),
+                (2, 3, 2, 2),
+                (3, 3, 2, 4),
+                -3,
+                True,
+            ),
+        ]
+    )
+    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sort(),
+            input_specs,
+            output_dtypes=[torch.float, torch.int64],
+            use_dynamo_tracer=True,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_trunc_aten.py b/tests/py/dynamo/conversion/test_trunc_aten.py
index 979ced17e2..211ddbf9d1 100644
--- a/tests/py/dynamo/conversion/test_trunc_aten.py
+++ b/tests/py/dynamo/conversion/test_trunc_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -48,5 +49,48 @@ def forward(self, input):
         )
 
 
+class TestTruncConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_int32",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 4, 5),
+                torch.int32,
+                False,
+            ),
+            (
+                "3d_dynamic_float32",
+                (2, 1, 1),
+                (2, 2, 2),
+                (2, 4, 5),
+                torch.float32,
+                True,
+            ),
+        ]
+    )
+    def test_trunc_dynamic(
+        self, _, min_shape, opt_shape, max_shape, type, enable_passes
+    ):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Trunc(),
+            input_specs,
+            enable_passes=enable_passes,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
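
A minimal usage sketch of the behavior this patch enables, assuming a CUDA device with TensorRT installed; the module name, shapes, and compile settings below are illustrative assumptions and are not taken from the patch itself. It compiles a module that lowers to aten.sort.default, aten.any.dim, and aten.trunc.default with a dynamic batch dimension, while the sort dimension stays static so sort_validator can still resolve k at compile time.

import torch
import torch_tensorrt


class SortAnyTrunc(torch.nn.Module):
    def forward(self, x):
        # torch.sort lowers to aten.sort.default, torch.any(..., dim=...) to
        # aten.any.dim, and torch.trunc to aten.trunc.default.
        values, indices = torch.sort(x, dim=1, descending=True)
        mask = torch.any(x > 0, dim=1)
        return torch.trunc(values), indices, mask


model = SortAnyTrunc().eval().cuda()

# Only the batch dimension is dynamic (1..8); the sorted dimension is fixed at 16.
dynamic_input = torch_tensorrt.Input(
    min_shape=(1, 16),
    opt_shape=(4, 16),
    max_shape=(8, 16),
    dtype=torch.float32,
)

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[dynamic_input])
print(trt_model(torch.randn(2, 16, device="cuda")))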