From 7323c86bca52c25a53fc886babb209e29aa24eb2 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Mon, 27 Nov 2023 10:30:24 -0800 Subject: [PATCH] fix: Bug in slice operator with default inputs (#2463) --- .../dynamo/conversion/aten_ops_converters.py | 16 +++++++-- .../dynamo/conversion/converter_utils.py | 10 +++--- .../dynamo/conversion/impl/slice/ops.py | 24 +++++-------- tests/py/dynamo/conversion/test_slice_aten.py | 35 ++++++++++++++----- 4 files changed, 53 insertions(+), 32 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b05713c360..0a1674ff94 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -687,6 +687,11 @@ def aten_ops_select( @dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) def aten_ops_slice( ctx: ConversionContext, target: Target, @@ -700,9 +705,9 @@ def aten_ops_slice( SourceIR.ATEN, name, args[0], - args[1], - args[2], - args[3], + args_bounds_check(args, 1, replacement=0), + args_bounds_check(args, 2, replacement=None), + args_bounds_check(args, 3, replacement=None), args_bounds_check(args, 4, replacement=1), ) @@ -900,6 +905,11 @@ def aten_ops_clone_copy_placeholder( @dynamo_tensorrt_converter(torch.ops.aten.expand.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) def aten_ops_expand( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index b65f95f0e5..f90c869c15 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np +import tensorrt as trt import torch from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target @@ -20,8 +21,6 @@ ) from torch_tensorrt.fx.types import TRTDataType, TRTTensor -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -339,8 +338,8 @@ def get_positive_dim( ) -> Union[int, Tuple[int, ...]]: """ Given an integer number or tuple that represents dimension(s) in the array, - transform it to a positive integer dim if it's negative. Otherwise, do - nothing. + transform it to a positive integer dim if it's negative. + Otherwise, truncate it to the dimension size Args: dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array. @@ -353,7 +352,8 @@ def get_positive_dim( def positive_dim(d: int) -> int: if d < 0: return d % dim_size - return d + else: + return min(d, dim_size) return ( positive_dim(dim) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 5619a4c2ba..91ac4a7042 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -27,15 +27,17 @@ def slice_op( # TODO: This should be slice not whatever is in base name: str, input: TRTTensor, dim: int, - start: int, - stop: int, + start: Optional[int], + stop: Optional[int], step: int, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input} that is not part " - "of the TensorRT region!" - ) + # Special case for start being None + if start is None: + start = 0 + + # Special case for stop being None + if stop is None: + stop = input.shape[dim] dim = get_positive_dim(dim, len(input.shape)) start = get_positive_dim(start, input.shape[dim]) @@ -45,9 +47,6 @@ def slice_op( # TODO: This should be slice not whatever is in base # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!" - if stop == 2**63 - 1: - stop = input.shape[dim] - start_slice = [0] * len(input.shape) start_slice[dim] = start stride_slice = [1] * len(input.shape) @@ -68,11 +67,6 @@ def expand( input_t: TRTTensor, shape: Shape, ) -> TRTTensor: - if not isinstance(input_t, TRTTensor): - raise RuntimeError( - f"expand received input {input_t} that is not a TensorRT ITensor" - ) - shape_rank = len(shape) initial_tensor_rank = len(input_t.shape) diff --git a/tests/py/dynamo/conversion/test_slice_aten.py b/tests/py/dynamo/conversion/test_slice_aten.py index 8c0d6dae42..b332fd3354 100644 --- a/tests/py/dynamo/conversion/test_slice_aten.py +++ b/tests/py/dynamo/conversion/test_slice_aten.py @@ -7,14 +7,16 @@ from .harness import DispatchTestCase -class TestSelectConverter(DispatchTestCase): +class TestSliceConverter(DispatchTestCase): @parameterized.expand( [ - ("select_dim_start_stop_step", 0, 0, 7, 2), - ("select_dim_start_stop_step_offset", 1, 0, 7, 2), - ("select_dim_start_stop_step_exact", 1, 0, 10, 2), - ("select_dim_start_stop_step_negatives", -3, -2, -1, 1), - ("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1), + ("slice_dim_start_stop_step", 0, 0, 7, 2), + ("slice_dim_start_stop_step_offset", 1, 0, 7, 2), + ("slice_dim_start_stop_step_exact", 1, 0, 10, 2), + ("slice_dim_start_stop_step_negatives", -3, -2, -1, 1), + ("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1), + ("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1), + ("slice_dim_start_stop_step_none", 2, None, None, 1), ] ) def test_slice(self, _, dim, start, stop, step): @@ -32,12 +34,27 @@ def forward(self, input): input, ) + def test_slice_empty(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + ) + -class TestSelectConverterDynamicShape(DispatchTestCase): +class TestSliceConverterDynamicShape(DispatchTestCase): @parameterized.expand( [ - ("select_dim_start_stop_step", 1, 0, 7, 2), - ("select_dim_start_stop_step", 1, 0, 10, 2), + ("slice_dim_start_stop_step", 1, 0, 7, 2), + ("slice_dim_start_stop_step", 1, 0, 10, 2), ] ) def test_slice(self, _, dim, start, stop, step):