Skip to content

Commit

Permalink
fix: Bug in slice operator with default inputs
Browse files Browse the repository at this point in the history
- Slice operator schema allows optional inputs for specified dimensions
and allows not specifying certain inputs at all
- Fix issue where the converter did not reflect these Torch behaviors
- Fix issue where bounds past the end of the array were not allowed (as
they are in Python, ONNX, Torch)
- Add regression tests to catch such errors
  • Loading branch information
gs-olive committed Nov 13, 2023
1 parent 4985c70 commit 7ed104c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 18 deletions.
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,9 @@ def aten_ops_slice(
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args_bounds_check(args, 1, replacement=0),
args_bounds_check(args, 2, replacement=None),
args_bounds_check(args, 3, replacement=None),
args_bounds_check(args, 4, replacement=1),
)

Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
Expand All @@ -20,8 +21,6 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -339,8 +338,8 @@ def get_positive_dim(
) -> Union[int, Tuple[int, ...]]:
"""
Given an integer number or tuple that represents dimension(s) in the array,
transform it to a positive integer dim if it's negative. Otherwise, do
nothing.
transform it to a positive integer dim if it's negative.
Otherwise, truncate it to the dimension size
Args:
dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array.
Expand All @@ -353,7 +352,8 @@ def get_positive_dim(
def positive_dim(d: int) -> int:
if d < 0:
return d % dim_size
return d
else:
return min(d, dim_size)

return (
positive_dim(dim)
Expand Down
15 changes: 10 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def slice_op( # TODO: This should be slice not whatever is in base
name: str,
input: TRTTensor,
dim: int,
start: int,
stop: int,
start: Optional[int],
stop: Optional[int],
step: int,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
Expand All @@ -37,6 +37,14 @@ def slice_op( # TODO: This should be slice not whatever is in base
"of the TensorRT region!"
)

# Special case for start being None
if start is None:
start = 0

# Special case for stop being None
if stop is None:
stop = input.shape[dim]

dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])
Expand All @@ -45,9 +53,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"

if stop == 2**63 - 1:
stop = input.shape[dim]

start_slice = [0] * len(input.shape)
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
Expand Down
27 changes: 22 additions & 5 deletions tests/py/dynamo/conversion/test_slice_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
class TestSelectConverter(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 0, 0, 7, 2),
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step", 0, 0, 7, 2),
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
("slice_dim_start_stop_step_none", 2, None, None, 1),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand All @@ -32,6 +34,21 @@ def forward(self, input):
input,
)

def test_slice_empty(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.ops.aten.slice.Tensor(input)
return out

input = [torch.randn(10, 10, 3, 1)]
self.run_test(
TestModule(),
input,
)


class TestSelectConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
Expand Down

0 comments on commit 7ed104c

Please sign in to comment.