fix: Allow rank differences in aten.expand (#2234)
gs-olive authored Aug 22, 2023
1 parent 1133432 commit 56b8950
Showing 4 changed files with 60 additions and 28 deletions.
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -420,3 +420,21 @@ def aten_ops_clone(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+def aten_ops_expand(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.expand(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
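For reference, the converter above simply forwards the input tensor and target shape to impl.slice.expand. A minimal PyTorch-level sketch of the rank-difference behavior it now needs to cover (the shapes here are illustrative, not taken from the change):

import torch

x = torch.randn(3, 1)        # rank-2 input
y = x.expand(2, 3, 4)        # rank-3 target shape: new dims are prepended on the left
assert y.shape == (2, 3, 4)  # the (3, 1) input behaves as (1, 3, 1) before expansion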
63 changes: 38 additions & 25 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Expand Up @@ -5,10 +5,10 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.fx.converters.converter_utils import (
broadcast,
get_positive_dim,
get_trt_tensor,
has_dynamic_shape,
prepend_ones,
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor

@@ -65,33 +65,46 @@ def expand(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    sizes: Shape,
+    input_t: TRTTensor,
+    shape: Shape,
 ) -> TRTTensor:
-    shape = list(sizes)
-
-    input_val = get_trt_tensor(network, input, f"{name}_input")
-
-    if network.has_implicit_batch_dimension:
-        shape = shape[1:]
-
-    ranks = len(input_val.shape)
-    # TRT does not support different dimension size
-    # though this condition is not seen in the case of bmm
-    # where input_t and shape dimensions are not equal
-    assert len(shape) >= ranks
-    if len(shape) != ranks:
-        shape_tuple = tuple([0] * len(shape))
-        shape_tensor = get_trt_tensor(network, input, f"{name}_shape")
-        input_val, shape_tensor = broadcast(
-            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
-        )
-        ranks = len(shape)
-
-    inshape = tuple(input_val.shape)
-    shape_t = tuple(shape)
-    start = tuple([0] * ranks)
+    if not isinstance(input_t, TRTTensor):
+        raise RuntimeError(
+            f"expand received input {input_t} that is not a TensorRT ITensor"
+        )
+
+    shape_rank = len(shape)
+    initial_tensor_rank = len(input_t.shape)
+
+    # If the rank of the input tensor is less than the shape's rank, pad with ones
+    if initial_tensor_rank < shape_rank:
+        input_t = prepend_ones(
+            network,
+            input_t,
+            name + "_expand_broadcast",
+            shape_rank - initial_tensor_rank,
+        )
+    # If the rank of the input tensor is more than the shape's rank, raise error
+    elif initial_tensor_rank > shape_rank:
+        raise RuntimeError(
+            f"expand called with {shape_rank}-dimensional shape on Tensor with {len(shape)} dimensions. "
+            "Cannot expand to shape with rank smaller than original tensor."
+        )
+
+    # After the above padding, the shape and tensor rank must be equal
+    assert len(input_t.shape) == shape_rank
+
+    # -1 denotes taking the shape from the original input tensor
+    shape = tuple(
+        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
+    )
+
+    # Establish the desired output shape, strides, and starting indices
+    input_tensor_shape = tuple(input_t.shape)
+    start = tuple([0] * shape_rank)
     stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
+        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    return slice(network, target, source_ir, name, input_val, start, shape_t, stride)
+    layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
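A quick sketch of the stride-based broadcast used in the new expand implementation: after prepend_ones pads the input to the target rank, the stride is 1 where an input dimension already matches the target and 0 where a size-1 dimension must be repeated. The shapes below are assumed for illustration, not taken from the commit:

input_shape = (1, 1, 5, 7)   # a rank-3 input (1, 5, 7) after prepending one leading 1
target_shape = (2, 3, 5, 7)  # requested output shape, with any -1 entries already resolved
stride = tuple(int(i == o) for i, o in zip(input_shape, target_shape))
assert stride == (0, 0, 1, 1)  # broadcast the first two dims, copy the last two as-is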
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -27,7 +27,7 @@ def unsqueeze(
         )

     dim = cast(int, dim)
-    input_shape = input_val.shape
+
     input_shape_size = (
         len(input_val.shape) + 1
         if network.has_implicit_batch_dimension
@@ -46,5 +46,5 @@ def unsqueeze(
     layer.reshape_dims = (
         tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
     )
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
3 changes: 2 additions & 1 deletion tests/py/dynamo/converters/test_expand_aten.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase


 class TestExpandConverter(DispatchTestCase):
@@ -12,6 +12,7 @@ class TestExpandConverter(DispatchTestCase):
             ("3d_dim", (2, 3, 4), (2, 1, 1)),
             ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
             ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
+            ("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
         ]
     )
     def test_expand(self, _, sizes, init_size):
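A sketch of what the new "different_ranks" case exercises, assuming the test module forwards its sizes to Tensor.expand: a rank-3 input is expanded to a rank-4 shape, with -1 keeping the original dimension size:

import torch

x = torch.randn(1, 5, 7)      # init_size = (1, 5, 7): rank-3 input
y = x.expand(2, 3, -1, -1)    # sizes = (2, 3, -1, -1): rank-4 target shape
assert y.shape == (2, 3, 5, 7)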