From 56b8950d9b488b6d1ea4757e4f2a344a6799b98c Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 22 Aug 2023 12:19:43 -0700
Subject: [PATCH] fix: Allow rank differences in `aten.expand` (#2234)

---
 .../dynamo/conversion/aten_ops_converters.py  | 18 ++++++
 .../dynamo/conversion/impl/slice/ops.py       | 63 +++++++++++--------
 .../dynamo/conversion/impl/unsqueeze.py       |  4 +-
 .../py/dynamo/converters/test_expand_aten.py  |  3 +-
 4 files changed, 60 insertions(+), 28 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 75a7782354..d9b49fafbe 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -420,3 +420,21 @@ def aten_ops_clone(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+def aten_ops_expand(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.expand(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 3835253219..28ed76169e 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -5,10 +5,10 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
     get_positive_dim,
-    get_trt_tensor,
     has_dynamic_shape,
+    prepend_ones,
+    set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
 
@@ -65,33 +65,46 @@ def expand(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    sizes: Shape,
+    input_t: TRTTensor,
+    shape: Shape,
 ) -> TRTTensor:
-    shape = list(sizes)
-
-    input_val = get_trt_tensor(network, input, f"{name}_input")
+    if not isinstance(input_t, TRTTensor):
+        raise RuntimeError(
+            f"expand received input {input_t} that is not a TensorRT ITensor"
+        )
 
-    if network.has_implicit_batch_dimension:
-        shape = shape[1:]
+    shape_rank = len(shape)
+    initial_tensor_rank = len(input_t.shape)
 
-    ranks = len(input_val.shape)
-    # TRT does not support different dimension size
-    # though this condition is not seen in the case of bmm
-    # where input_t and shape dimensions are not equal
-    assert len(shape) >= ranks
-    if len(shape) != ranks:
-        shape_tuple = tuple([0] * len(shape))
-        shape_tensor = get_trt_tensor(network, input, f"{name}_shape")
-        input_val, shape_tensor = broadcast(
-            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+    # If the rank of the input tensor is less than the shape's rank, pad with ones
+    if initial_tensor_rank < shape_rank:
+        input_t = prepend_ones(
+            network,
+            input_t,
+            name + "_expand_broadcast",
+            shape_rank - initial_tensor_rank,
         )
-        ranks = len(shape)
+    # If the rank of the input tensor is greater than the shape's rank, raise error
+    elif initial_tensor_rank > shape_rank:
+        raise RuntimeError(
+            f"expand called with {shape_rank}-dimensional shape on Tensor with {initial_tensor_rank} dimensions. "
+            "Cannot expand to shape with rank smaller than original tensor."
+        )
+
+    # After the above padding, the shape and tensor rank must be equal
+    assert len(input_t.shape) == shape_rank
+
+    # -1 denotes taking the shape from the original input tensor
+    shape = tuple(
+        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
+    )
 
-    inshape = tuple(input_val.shape)
-    shape_t = tuple(shape)
-    start = tuple([0] * ranks)
+    # Establish the desired output shape, strides, and starting indices
+    input_tensor_shape = tuple(input_t.shape)
+    start = tuple([0] * shape_rank)
     stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
+        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    return slice(network, target, source_ir, name, input_val, start, shape_t, stride)
+    layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index b16fee1eec..9929e59d86 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -27,7 +27,7 @@ def unsqueeze(
         )
 
     dim = cast(int, dim)
-    input_shape = input_val.shape
+
     input_shape_size = (
         len(input_val.shape) + 1
         if network.has_implicit_batch_dimension
@@ -46,5 +46,5 @@ def unsqueeze(
     layer.reshape_dims = (
         tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
     )
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/tests/py/dynamo/converters/test_expand_aten.py b/tests/py/dynamo/converters/test_expand_aten.py
index bb5b93304a..16283c41a6 100644
--- a/tests/py/dynamo/converters/test_expand_aten.py
+++ b/tests/py/dynamo/converters/test_expand_aten.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase
 
 
 class TestExpandConverter(DispatchTestCase):
@@ -12,6 +12,7 @@ class TestExpandConverter(DispatchTestCase):
             ("3d_dim", (2, 3, 4), (2, 1, 1)),
             ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
             ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
+            ("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
         ]
     )
     def test_expand(self, _, sizes, init_size):
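
For reference only (not part of the patch): a minimal PyTorch sketch of the rank-mismatched expand that the new "different_ranks" test case exercises. The converter above handles this case by prepending ones to the lower-rank input and then emitting a slice layer whose stride is 0 on broadcast dimensions.

    import torch

    # Mirrors the ("different_ranks", (2, 3, -1, -1), (1, 5, 7)) test case:
    # a rank-3 input is expanded to a rank-4 target shape, so aten.expand.default
    # receives a shape whose rank exceeds the input tensor's rank.
    x = torch.randn(1, 5, 7)
    y = x.expand(2, 3, -1, -1)  # -1 keeps the size of an existing dimension
    print(y.shape)  # torch.Size([2, 3, 5, 7])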