Skip to content

Commit

Permalink
chunk converter validator (#3120)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored Sep 6, 2024
1 parent 29b4913 commit b4b22c3
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 79 deletions.
24 changes: 0 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,30 +924,6 @@ def aten_ops_slice(
)


@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_chunk(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Lower ``aten.chunk.default`` to the slice-based chunk implementation.

    ``args`` carries ``(input, chunks, dim)``; ``dim`` is optional and
    defaults to 0 when the node omits it.
    """
    # Positional schema: args[0]=input tensor, args[1]=number of chunks.
    input_tensor = args[0]
    num_chunks = args[1]
    chunk_dim = args_bounds_check(args, 2, 0)
    return impl.slice.chunk(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input_tensor,
        num_chunks,
        chunk_dim,
    )


@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
Expand Down
55 changes: 0 additions & 55 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,61 +324,6 @@ def expand(
return layer.get_output(0)


def chunk(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    chunks: int,
    dim: int,
) -> List[TRTTensor]:
    """Split `input` into up to `chunks` slices along dimension `dim`.

    Mirrors ``torch.chunk`` semantics: each slice spans
    ``ceil(shape[dim] / chunks)`` elements except possibly the last one,
    so fewer than `chunks` slices may be produced.

    Args:
        ctx: Conversion context holding the TensorRT network being built.
        target: The fx node target being converted.
        source_ir: Originating IR of the op (e.g. ATEN), used for naming.
        name: Base name for the generated slice layers.
        input: The TensorRT tensor to split.
        chunks: Requested number of chunks; must be > 0.
        dim: Dimension to split along; may be negative.

    Returns:
        A list of TRTTensor slices (one per produced chunk).

    Raises:
        RuntimeError: If `chunks` <= 0 or `dim` is out of range.
    """
    if chunks <= 0:
        raise RuntimeError(
            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
        )

    shape = input.shape
    dim = get_positive_dim(dim, len(shape))

    if dim >= len(shape):
        raise RuntimeError(
            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
        )

    # Fixed: `has_dynamic_shape` returns a boolean, so test it directly
    # rather than comparing against 0.
    if has_dynamic_shape(input.shape):
        # The split dimension itself must be concrete; only other
        # dimensions may be dynamic (-1).
        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

    size_dim = shape[dim]
    chunk_size = math.ceil(size_dim / chunks)
    result = []
    start = 0
    end = min(start + chunk_size, size_dim)
    cnt = 0

    # Emit one slice layer per chunk; the final chunk is clamped to the
    # dimension size, matching torch.chunk's uneven-tail behavior.
    while start < end:
        result.append(
            slice_op(
                ctx,
                target,
                source_ir,
                f"{name}_slice_{cnt}",
                input,
                dim,
                start,
                end,
                1,
            )
        )
        start = end
        end = min(start + chunk_size, size_dim)
        cnt += 1

    return result


def cumsum(
ctx: ConversionContext,
target: Target,
Expand Down
105 changes: 105 additions & 0 deletions tests/py/dynamo/conversion/test_chunk_aten.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import unittest

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand All @@ -27,6 +30,7 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)

@parameterized.expand(
Expand All @@ -51,6 +55,7 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)

@parameterized.expand(
Expand All @@ -75,6 +80,106 @@ def forward(self, input):
self.run_test(
TestChunk(),
input,
use_dynamo_tracer=True,
)


####################### Dynamic cases #######################
# The tests below are skipped for now and will be re-enabled once
# https://github.com/pytorch/pytorch/issues/134663 is resolved.
@unittest.skip(
    "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663"
)
class TestChunkDynamicConverter(DispatchTestCase):
    """Dynamic-shape conversion tests for ``aten.chunk.default``."""

    @parameterized.expand(
        [
            ((1,), (1,), (3,), 3, 0),
            ((3,), (3,), (4,), 3, 0),
            ((4,), (4,), (6,), 3, 0),
            ((6,), (6,), (9,), 3, 0),
            ((3,), (3,), (4,), 1, -1),
            ((3,), (3,), (4,), 3, -1),
            ((3,), (3,), (4,), 4, -1),
        ]
    )
    def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
        """Chunk a rank-1 tensor whose single dimension is dynamic."""

        class ChunkModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.chunk.default(input, chunks, dim)

        dynamic_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            ChunkModule(),
            dynamic_specs,
            use_dynamo_tracer=True,
        )

    @parameterized.expand(
        [
            ((3, 4), (3, 4), (4, 4), 1, 0),
            ((3, 4), (3, 4), (4, 4), 3, 0),
            ((3, 4), (3, 4), (4, 4), 4, 0),
            ((3, 4), (3, 4), (4, 4), 2, -2),
            ((3, 4), (3, 4), (4, 4), 6, -2),
            ((3, 4), (3, 4), (4, 4), 3, 1),
            ((3, 4), (3, 4), (4, 4), 4, 1),
            ((3, 4), (3, 4), (4, 4), 5, -1),
        ]
    )
    def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
        """Chunk a rank-2 tensor along positive and negative dims."""

        class ChunkModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.chunk.default(input, chunks, dim)

        dynamic_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            ChunkModule(),
            dynamic_specs,
            use_dynamo_tracer=True,
        )

    @parameterized.expand(
        [
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
        ]
    )
    def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
        """Chunk a rank-3 tensor along positive and negative dims."""

        class ChunkModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.chunk.default(input, chunks, dim)

        dynamic_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            ChunkModule(),
            dynamic_specs,
            use_dynamo_tracer=True,
        )


Expand Down

0 comments on commit b4b22c3

Please sign in to comment.