diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 98ec69eb44..7d13ee790f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2876,7 +2876,7 @@ def aten_ops_resize(
 
 
 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default, supports_dynamic_shapes=True)
 def aten_ops_argmax(
     ctx: ConversionContext,
     target: Target,
@@ -2896,7 +2896,7 @@ def aten_ops_argmax(
 
 
 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmin.default, supports_dynamic_shapes=True)
 def aten_ops_argmin(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
index 007f248af1..3b6549d285 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -11,9 +11,13 @@
     get_axes_for_reduce_op,
     get_positive_dim,
     set_layer_name,
+    get_trt_tensor,
+    has_dynamic_shape,
 )
-from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.types import TRTTensor
 
 
 def argmax_argmin(
@@ -34,12 +38,60 @@ def argmax_argmin(
     # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
     # 3. normal cases, no additional handlings
     out = input
+    is_dynamic_present = has_dynamic_shape(input.shape)
 
     if dim is None:
-        new_shape = (*flatten_dims(input, 0, -1), 1)
-        out = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_flatten", input, new_shape
-        )
+        if is_dynamic_present and len(input.shape) != 1:
+            multiplier = get_trt_tensor(ctx, 1, name + "_shape")
+            for i in range(0, len(input.shape)):
+                if input.shape[i] != DYNAMIC_DIM:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        input.shape[i],
+                    )
+                else:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + f"_shape_dim_stop_{i}",
+                            input,
+                            i,
+                        ),
+                    )
+            # form shape tensor
+            new_shape_layer = ctx.net.add_concatenation(
+                [multiplier, get_trt_tensor(ctx, 1, name + "_one_shape")]
+            )
+            set_layer_name(
+                new_shape_layer, target, name + "_new_shape_concat", source_ir
+            )
+            concat_tensor = new_shape_layer.get_output(0)
+
+            reshape_dynamic_layer = ctx.net.add_shuffle(input)
+            reshape_dynamic_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                reshape_dynamic_layer, target, name + "_reshape_layer", source_ir
+            )
+            out = reshape_dynamic_layer.get_output(0)
+
+        else:
+            new_shape = (*flatten_dims(input, 0, -1), 1)
+            out = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_flatten", input, new_shape
+            )
     elif len(input.shape) == 1:
         new_shape = (*input.shape, 1)
         out = impl.shuffle.reshape(
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 9503bb0784..39a2f8ccde 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -382,6 +382,7 @@ def run_test_with_dynamic_shape(
         use_example_tensors=True,
         pyt_inputs=None,
         propagate_shapes=False,
+        check_dtype=True,
     ):
         mod = self.generate_graph(
             mod,
@@ -395,6 +396,14 @@
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_double=True)
 
+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_double=compilation_settings.truncate_double,
+            )
+
         interp = TRTInterpreter(
             mod,
             input_specs,
diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py
index a3f9f67b95..a936040f82 100644
--- a/tests/py/dynamo/conversion/test_argmax_aten.py
+++ b/tests/py/dynamo/conversion/test_argmax_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -36,6 +37,43 @@ def forward(self, input):
 
         self.run_test(ArgMax(), input)
 
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmax_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMax(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_argmin_aten.py b/tests/py/dynamo/conversion/test_argmin_aten.py
index f06284f394..c0290f943e 100644
--- a/tests/py/dynamo/conversion/test_argmin_aten.py
+++ b/tests/py/dynamo/conversion/test_argmin_aten.py
@@ -36,6 +36,43 @@ def forward(self, input):
 
         self.run_test(ArgMin(), input)
 
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmin_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMin(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmin.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMin(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
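Note (not part of the diff): a minimal PyTorch-level sketch of the semantics the new dim-is-None path in argmax_argmin reproduces. With dim=None, aten argmax/argmin return an index into the flattened tensor, so the converter reshapes the input to (numel, 1) before running TopK; when a dimension is dynamic, numel is assembled at runtime as an elementwise product over the per-dimension shape tensors. The tensor shape below is arbitrary.

    import torch

    x = torch.randn(3, 4, 5)

    # dim=None: PyTorch returns the index into the flattened tensor
    ref = torch.argmax(x)

    # equivalent view used by the converter: flatten to (numel, 1), reduce over axis 0
    flattened = x.reshape(-1, 1)
    via_reshape = torch.argmax(flattened, dim=0)

    assert ref.item() == via_reshape.item()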