diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9026158126..33f0de0ead 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -687,6 +687,11 @@ def aten_ops_select( @dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) def aten_ops_slice( ctx: ConversionContext, target: Target, @@ -700,9 +705,9 @@ def aten_ops_slice( SourceIR.ATEN, name, args[0], - args[1], - args[2], - args[3], + args_bounds_check(args, 1, replacement=0), + args_bounds_check(args, 2, replacement=None), + args_bounds_check(args, 3, replacement=None), args_bounds_check(args, 4, replacement=1), ) @@ -877,6 +882,11 @@ def aten_ops_clone_copy_placeholder( @dynamo_tensorrt_converter(torch.ops.aten.expand.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) def aten_ops_expand( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 724200fa2b..4f56ffbd85 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -339,8 +339,8 @@ def get_positive_dim( ) -> Union[int, Tuple[int, ...]]: """ Given an integer number or tuple that represents dimension(s) in the array, - transform it to a positive integer dim if it's negative. Otherwise, do - nothing. + transform it to a positive integer dim if it's negative. + Otherwise, truncate it to the dimension size Args: dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array. @@ -353,7 +353,8 @@ def get_positive_dim( def positive_dim(d: int) -> int: if d < 0: return d % dim_size - return d + else: + return min(d, dim_size) return ( positive_dim(dim) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 8a77508014..5c9ed2ef9c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -21,15 +21,17 @@ def slice_op( # TODO: This should be slice not whatever is in base name: str, input: TRTTensor, dim: int, - start: int, - stop: int, + start: Optional[int], + stop: Optional[int], step: int, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input} that is not part " - "of the TensorRT region!" - ) + # Special case for start being None + if start is None: + start = 0 + + # Special case for stop being None + if stop is None: + stop = input.shape[dim] dim = get_positive_dim(dim, len(input.shape)) start = get_positive_dim(start, input.shape[dim]) @@ -39,9 +41,6 @@ def slice_op( # TODO: This should be slice not whatever is in base # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!" - if stop == 2**63 - 1: - stop = input.shape[dim] - start_slice = [0] * len(input.shape) start_slice[dim] = start stride_slice = [1] * len(input.shape) @@ -62,11 +61,6 @@ def expand( input_t: TRTTensor, shape: Shape, ) -> TRTTensor: - if not isinstance(input_t, TRTTensor): - raise RuntimeError( - f"expand received input {input_t} that is not a TensorRT ITensor" - ) - shape_rank = len(shape) initial_tensor_rank = len(input_t.shape) diff --git a/tests/py/dynamo/conversion/test_slice_aten.py b/tests/py/dynamo/conversion/test_slice_aten.py index 8c0d6dae42..b332fd3354 100644 --- a/tests/py/dynamo/conversion/test_slice_aten.py +++ b/tests/py/dynamo/conversion/test_slice_aten.py @@ -7,14 +7,16 @@ from .harness import DispatchTestCase -class TestSelectConverter(DispatchTestCase): +class TestSliceConverter(DispatchTestCase): @parameterized.expand( [ - ("select_dim_start_stop_step", 0, 0, 7, 2), - ("select_dim_start_stop_step_offset", 1, 0, 7, 2), - ("select_dim_start_stop_step_exact", 1, 0, 10, 2), - ("select_dim_start_stop_step_negatives", -3, -2, -1, 1), - ("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1), + ("slice_dim_start_stop_step", 0, 0, 7, 2), + ("slice_dim_start_stop_step_offset", 1, 0, 7, 2), + ("slice_dim_start_stop_step_exact", 1, 0, 10, 2), + ("slice_dim_start_stop_step_negatives", -3, -2, -1, 1), + ("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1), + ("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1), + ("slice_dim_start_stop_step_none", 2, None, None, 1), ] ) def test_slice(self, _, dim, start, stop, step): @@ -32,12 +34,27 @@ def forward(self, input): input, ) + def test_slice_empty(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + ) + -class TestSelectConverterDynamicShape(DispatchTestCase): +class TestSliceConverterDynamicShape(DispatchTestCase): @parameterized.expand( [ - ("select_dim_start_stop_step", 1, 0, 7, 2), - ("select_dim_start_stop_step", 1, 0, 10, 2), + ("slice_dim_start_stop_step", 1, 0, 7, 2), + ("slice_dim_start_stop_step", 1, 0, 10, 2), ] ) def test_slice(self, _, dim, start, stop, step):