From b4b22c3016cd49f8a759bc8bbf25a3c8761f601e Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:31:06 -0700 Subject: [PATCH] chunk converter validator (#3120) --- .../dynamo/conversion/aten_ops_converters.py | 24 ---- .../dynamo/conversion/impl/slice/ops.py | 55 --------- tests/py/dynamo/conversion/test_chunk_aten.py | 105 ++++++++++++++++++ 3 files changed, 105 insertions(+), 79 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 92dfddc44f..a757cf023e 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -924,30 +924,6 @@ def aten_ops_slice( ) -@dynamo_tensorrt_converter(torch.ops.aten.chunk.default) -@enforce_tensor_types( - { - 0: (TRTTensor,), - } -) -def aten_ops_chunk( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.slice.chunk( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - args[1], - args_bounds_check(args, 2, 0), - ) - - @dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 04eab08c47..b58435b489 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -324,61 +324,6 @@ def expand( return layer.get_output(0) -def chunk( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - chunks: int, - dim: int, -) -> TRTTensor: - if chunks <= 0: - raise RuntimeError( - f"chunk expects `chunks` to be greater than 0, got: {chunks}" - ) - - shape = input.shape - dim = get_positive_dim(dim, len(shape)) - - if dim >= len(shape): - raise RuntimeError( - f"chunk expects `dim` to be less than the length of input shape, got: {dim}" - ) - - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape > 0: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" - - size_dim = shape[dim] - chunk_size = math.ceil(size_dim / chunks) - result = [] - start = 0 - end = min(start + chunk_size, size_dim) - cnt = 0 - - while start < end: - result.append( - slice_op( - ctx, - target, - source_ir, - f"{name}_slice_{cnt}", - input, - dim, - start, - end, - 1, - ) - ) - start = end - end = min(start + chunk_size, size_dim) - cnt += 1 - - return result - - def cumsum( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_chunk_aten.py b/tests/py/dynamo/conversion/test_chunk_aten.py index 1812165b43..eb06c04201 100644 --- a/tests/py/dynamo/conversion/test_chunk_aten.py +++ b/tests/py/dynamo/conversion/test_chunk_aten.py @@ -1,6 +1,9 @@ +import unittest + import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -27,6 +30,7 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -51,6 +55,7 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -75,6 +80,106 @@ def forward(self, input): self.run_test( TestChunk(), input, + use_dynamo_tracer=True, + ) + + +#######################Dynamic cases####################### +# The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed +@unittest.skip( + "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663" +) +class TestChunkDynamicConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1,), (1,), (3,), 3, 0), + ((3,), (3,), (4,), 3, 0), + ((4,), (4,), (6,), 3, 0), + ((6,), (6,), (9,), 3, 0), + ((3,), (3,), (4,), 1, -1), + ((3,), (3,), (4,), 3, -1), + ((3,), (3,), (4,), 4, -1), + ] + ) + def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, + ) + + @parameterized.expand( + [ + ((3, 4), (3, 4), (4, 4), 1, 0), + ((3, 4), (3, 4), (4, 4), 3, 0), + ((3, 4), (3, 4), (4, 4), 4, 0), + ((3, 4), (3, 4), (4, 4), 2, -2), + ((3, 4), (3, 4), (4, 4), 6, -2), + ((3, 4), (3, 4), (4, 4), 3, 1), + ((3, 4), (3, 4), (4, 4), 4, 1), + ((3, 4), (3, 4), (4, 4), 5, -1), + ] + ) + def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, + ) + + @parameterized.expand( + [ + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1), + ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1), + ] + ) + def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + TestChunk(), + input_specs, + use_dynamo_tracer=True, )