Skip to content

Commit

Permalink
dynamic shape argmax and argmin (#3009)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored Aug 8, 2024
1 parent fdaba9a commit baa2eb1
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 7 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,7 +2876,7 @@ def aten_ops_resize(


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default, supports_dynamic_shapes=True)
def aten_ops_argmax(
ctx: ConversionContext,
target: Target,
Expand All @@ -2896,7 +2896,7 @@ def aten_ops_argmax(


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
@dynamo_tensorrt_converter(torch.ops.aten.argmin.default, supports_dynamic_shapes=True)
def aten_ops_argmin(
ctx: ConversionContext,
target: Target,
Expand Down
62 changes: 57 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
get_axes_for_reduce_op,
get_positive_dim,
set_layer_name,
get_trt_tensor,
has_dynamic_shape,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.dynamo.types import TRTTensor


def argmax_argmin(
Expand All @@ -34,12 +38,60 @@ def argmax_argmin(
# 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
# 3. normal cases, no additional handlings
out = input
is_dynamic_present = has_dynamic_shape(input.shape)

if dim is None:
new_shape = (*flatten_dims(input, 0, -1), 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_flatten", input, new_shape
)
if is_dynamic_present and len(input.shape) != 1:
multiplier = get_trt_tensor(ctx, 1, name + "_shape")
for i in range(0, len(input.shape)):
if input.shape[i] != DYNAMIC_DIM:
multiplier = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
input.shape[i],
)
else:
multiplier = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
get_shape(
ctx,
target,
source_ir,
name + f"_shape_dim_stop_{i}",
input,
i,
),
)
# form shape tensor
new_shape_layer = ctx.net.add_concatenation(
[multiplier, get_trt_tensor(ctx, 1, name + "_one_shape")]
)
set_layer_name(
new_shape_layer, target, name + "_new_shape_concat", source_ir
)
concat_tensor = new_shape_layer.get_output(0)

reshape_dynamic_layer = ctx.net.add_shuffle(input)
reshape_dynamic_layer.set_input(1, concat_tensor)
set_layer_name(
reshape_dynamic_layer, target, name + "_reshape_layer", source_ir
)
out = reshape_dynamic_layer.get_output(0)

else:
new_shape = (*flatten_dims(input, 0, -1), 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_flatten", input, new_shape
)
elif len(input.shape) == 1:
new_shape = (*input.shape, 1)
out = impl.shuffle.reshape(
Expand Down
9 changes: 9 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def run_test_with_dynamic_shape(
use_example_tensors=True,
pyt_inputs=None,
propagate_shapes=False,
check_dtype=True,
):
mod = self.generate_graph(
mod,
Expand All @@ -395,6 +396,14 @@ def run_test_with_dynamic_shape(
# We replicate this behavior here
compilation_settings = CompilationSettings(truncate_double=True)

if check_dtype:
output_dtypes = infer_module_output_dtypes(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

interp = TRTInterpreter(
mod,
input_specs,
Expand Down
38 changes: 38 additions & 0 deletions tests/py/dynamo/conversion/test_argmax_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -36,6 +37,43 @@ def forward(self, input):

self.run_test(ArgMax(), input)

@parameterized.expand(
[
# input dimension == 1
("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
# dim == None
("dim_1_none_true", (1,), (3,), (3,), None, True),
("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
# common cases
("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
]
)
def test_argmax_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
class ArgMax(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.argmax.default(input, dim, keep_dim)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
ArgMax(),
input_specs,
)


if __name__ == "__main__":
run_tests()
37 changes: 37 additions & 0 deletions tests/py/dynamo/conversion/test_argmin_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,43 @@ def forward(self, input):

self.run_test(ArgMin(), input)

@parameterized.expand(
[
# input dimension == 1
("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
# dim == None
("dim_1_none_true", (1,), (3,), (3,), None, True),
("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
# common cases
("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
]
)
def test_argmin_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
class ArgMin(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.argmin.default(input, dim, keep_dim)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
ArgMin(),
input_specs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit baa2eb1

Please sign in to comment.