Skip to content

Commit

Permalink
fix: Converter, inputs, and utils bugfixes for Transformer XL (#2404)
Browse files: browse the repository at this point in the history
  • Loading branch information
gs-olive authored Oct 24, 2023
1 parent 4e5b0f6 commit 73e887c
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 63 deletions.
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def _pretraced_backend(

gm = apply_lowering_passes(gm, sample_inputs)

torchtrt_inputs = prepare_inputs(sample_inputs)
torchtrt_inputs = prepare_inputs(
sample_inputs, disable_memory_format_check=True
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
Expand Down
11 changes: 10 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@ def convert_module(
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

output_dtypes = [output.dtype for output in module_outputs]
# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum. When truncate_long_and_double is enabled, int64 and
# float64 output dtypes are recorded as int32 and float32, respectively.
output_dtypes = []
for output in module_outputs:
if settings.truncate_long_and_double and output.dtype == torch.float64:
output_dtypes.append(torch.float32)
elif settings.truncate_long_and_double and output.dtype == torch.int64:
output_dtypes.append(torch.int32)
else:
output_dtypes.append(output.dtype)

interpreter = TRTInterpreter(
module,
Expand Down
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ def sum(
dim: Optional[Union[int, Sequence[int]]],
keepdim: bool,
) -> TRTTensor:
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool):
input_val = cast_trt_tensor(ctx, input_val, trt.int32, name)

if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))
Expand Down
37 changes: 15 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,23 @@ def slice_op( # TODO: This should be slice not whatever is in base
"of the TensorRT region!"
)

ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
dim = get_positive_dim(dim, ranks)
dynamic_shape = has_dynamic_shape(input.shape)
if ctx.net.has_implicit_batch_dimension:
if dim == 0:
raise RuntimeError(
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
)
dim = dim - 1
else:
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
start_int = start
stop_int = stop
if stop_int == 2**63 - 1:
stop_int = input.shape[dim]
step_int = step
dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])

if has_dynamic_shape(input.shape):
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"

if stop == 2**63 - 1:
stop = input.shape[dim]

start_slice = [0] * len(input.shape)
start_slice[dim] = start_int
stride_slice = [1] * len(start_slice)
stride_slice[dim] = step_int
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
stride_slice[dim] = step
output_shape = list(input.shape)
output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
output_shape[dim] = math.ceil((stop - start) / step)

return slice(
ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
Expand Down
19 changes: 14 additions & 5 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from typing import Any, Callable, Dict, Optional, Sequence, Union

import torch
import torch_tensorrt
from torch_tensorrt._Device import Device
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._defaults import PRECISION

import torch_tensorrt
from packaging import version

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -104,25 +104,32 @@ def set_log_level(parent_logger: Any, level: Any) -> None:

def prepare_inputs(
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
disable_memory_format_check: bool = False,
) -> Any:
if isinstance(inputs, Input):
return inputs

elif isinstance(inputs, torch.Tensor):
return Input.from_tensor(inputs)
return Input.from_tensor(
inputs, disable_memory_format_check=disable_memory_format_check
)

elif isinstance(inputs, list):
torchtrt_input_list = []
for input_obj in inputs:
torchtrt_input = prepare_inputs(input_obj)
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
)
torchtrt_input_list.append(torchtrt_input)

return torchtrt_input_list

elif isinstance(inputs, tuple):
torchtrt_inputs_tup = []
for input_obj in inputs:
torchtrt_input = prepare_inputs(input_obj)
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
)
torchtrt_inputs_tup.append(torchtrt_input)

return tuple(torchtrt_inputs_tup)
Expand All @@ -131,7 +138,9 @@ def prepare_inputs(
torchtrt_inputs_dict: Dict[Any, Any] = dict()

for key, input_obj in inputs.items():
torchtrt_input = prepare_inputs(input_obj)
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
)
torchtrt_inputs_dict[key] = torchtrt_input

return torchtrt_inputs_dict
Expand Down
28 changes: 5 additions & 23 deletions tests/py/dynamo/conversion/test_slice_aten.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,20 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestSelectConverterImplicitBatch(DispatchTestCase):
class TestSelectConverter(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 0, 0, 7, 2),
]
)
def test_slice(self, _, dim, start, stop, step):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
return out

input = [torch.randn(10, 2, 3, 1)]
self.run_test(
TestModule(),
input,
)


class TestSelectConverterExplicitBatch(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 1, 0, 7, 2),
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand Down
16 changes: 9 additions & 7 deletions tests/py/dynamo/conversion/test_sum_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def forward(self, x):

@parameterized.expand(
[
((3, 2, 4), 1, True, torch.int, 0, 5),
((2, 3, 4, 5), None, True, torch.int, -10, 10),
((3, 2, 4), 1, True, torch.int32, 0, 5),
((2, 3, 4, 5), None, True, torch.int32, -10, 10),
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
]
Expand All @@ -85,16 +85,18 @@ def forward(self, x):
self.run_test(
Sum(),
inputs,
check_dtype=False,
output_dtypes=[torch.int32],
)

@parameterized.expand(
[
((1, 2, 4), [], True, torch.int, 0, 5),
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((1, 2, 4), [], True, torch.int32, 0, 5),
((3, 2, 4), [1], True, torch.int32, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int32, -10, 10),
((2, 3, 4, 5), None, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.bool, 0, 2),
((4, 7, 1, 5), None, True, torch.bool, 0, 2),
]
)
def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
Expand All @@ -106,7 +108,7 @@ def forward(self, x):
self.run_test(
Sum(),
inputs,
check_dtype=False,
output_dtypes=[torch.int32],
)


Expand Down

0 comments on commit 73e887c

Please sign in to comment.