
Commit

feat: support cumsum dynamo converter (#2403)
zewenli98 authored Oct 31, 2023
1 parent b5efb6e commit f617898
Showing 3 changed files with 139 additions and 1 deletion.
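For context, a minimal sketch of how the new converter would be exercised end to end through the dynamo frontend. This is illustrative only: the toy module name and shapes are made up, and it assumes a CUDA device and a torch_tensorrt build containing this commit, compiled via the standard torch_tensorrt.compile(ir="dynamo") flow.

import torch
import torch_tensorrt  # assumes a build that includes this commit

class CumsumModule(torch.nn.Module):  # illustrative toy module, not part of the commit
    def forward(self, x):
        return torch.cumsum(x, dim=1)  # lowers to aten.cumsum.default

model = CumsumModule().eval().cuda()
inputs = [torch.randn(2, 3, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))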
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,29 @@ def aten_ops_chunk(
)


@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_cumsum(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.slice.cumsum(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@enforce_tensor_types(
    {
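The converter above simply forwards to the shared slice implementation and relies on the argument order of the ATen overload, aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None), so args[0] is the input tensor and args[1] is the dimension. A quick eager-mode check of that overload (illustrative, not part of the commit):

import torch

x = torch.arange(6.0).reshape(2, 3)
# aten.cumsum.default(self, dim) -- the same overload the converter is registered for
print(torch.ops.aten.cumsum.default(x, 1))
# tensor([[ 0.,  1.,  3.],
#         [ 3.,  7., 12.]])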
48 changes: 47 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -1,10 +1,16 @@
import math
from typing import Optional

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
@@ -157,3 +163,43 @@ def chunk(
        cnt += 1

    return result


def cumsum(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
) -> TRTTensor:
    input_shape = input.shape
    dim = get_positive_dim(dim, len(input_shape))

    # Build a TRT loop that runs once per element along `dim`.
    loop = ctx.net.add_loop()
    axis = np.array(input_shape[dim])
    trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)

    # The iterator yields one slice of `input` along `dim` per iteration.
    iterator = loop.add_iterator(input, dim, reverse=False)
    data = iterator.get_output(0)

    # Recurrence state: the running sum, initialized to zeros shaped like one slice.
    new_dims = tuple(data.shape)
    zeros = np.zeros(new_dims)
    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

    running_sum = loop.add_recurrence(zero_trttensor)
    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
    running_sum_tensor = running_sum.get_output(0)

    # current_sum = previous running sum + current slice; feed it back into the recurrence.
    current_sum = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_elementwise_add",
        data,
        running_sum_tensor,
    )
    running_sum.set_input(1, current_sum)

    # Concatenate the per-iteration sums back along `dim` to form the cumulative sum.
    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
    set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
    loop_output.set_input(1, trip_limit)
    return loop_output.get_output(0)
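
For intuition, the loop construction above mirrors the following plain-PyTorch reference (a hedged sketch, not part of the commit): each trip of the loop reads one slice of the input along dim via the iterator, adds it to the recurrence (the running sum), and the per-iteration sums are concatenated back along dim.

import torch

def cumsum_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Normalize negative dims, as get_positive_dim does in the converter.
    dim = dim % x.dim()
    # Recurrence state: zeros shaped like one slice along `dim`.
    running = torch.zeros_like(x.select(dim, 0))
    slices = []
    for i in range(x.shape[dim]):          # one trip per element along `dim`
        running = running + x.select(dim, i)
        slices.append(running)
    return torch.stack(slices, dim=dim)    # concatenate results back along `dim`

x = torch.randn(2, 3)
assert torch.allclose(cumsum_reference(x, 1), torch.cumsum(x, 1))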
69 changes: 69 additions & 0 deletions tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestCumsumConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((1,), 0),
            ((2,), 0),
            ((3,), -1),
        ]
    )
    def test_cumsum_1D(self, shape, dim):
        class Cumsum(nn.Module):
            def forward(self, x):
                return torch.ops.aten.cumsum.default(x, dim)

        inputs = [torch.randn(shape)]
        self.run_test(
            Cumsum(),
            inputs,
        )

    @parameterized.expand(
        [
            ((3, 1), 0),
            ((3, 1), 1),
            ((2, 3), -1),
            ((2, 3), -2),
        ]
    )
    def test_cumsum_2D(self, shape, dims):
        class Cumsum(nn.Module):
            def forward(self, x):
                return torch.ops.aten.cumsum.default(x, dims)

        inputs = [torch.randn(shape)]
        self.run_test(
            Cumsum(),
            inputs,
        )

    @parameterized.expand(
        [
            ((4, 2, 3), 0),
            ((4, 2, 3), 1),
            ((1, 2, 3), 2),
            ((1, 2, 3), -1),
            ((1, 2, 3), -2),
        ]
    )
    def test_cumsum_3D(self, shape, dims):
        class Cumsum(nn.Module):
            def forward(self, x):
                return torch.ops.aten.cumsum.default(x, dims)

        inputs = [torch.randn(shape)]
        self.run_test(
            Cumsum(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
