Expose IGridSampleLayer #2290

Merged: 10 commits, Nov 28, 2023
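For context: this PR adds converters that map torch.ops.aten.grid_sampler / torch.ops.aten.grid_sampler_2d onto TensorRT's IGridSampleLayer (INetworkDefinition.add_grid_sample) in the dynamo frontend. A minimal, hypothetical usage sketch, not part of the diff (the module, shapes, and modes are illustrative, and it assumes a CUDA build of Torch-TensorRT where F.grid_sample lowers to aten.grid_sampler in the exported graph):

import torch
import torch.nn.functional as F
import torch_tensorrt


class Warp(torch.nn.Module):  # illustrative module, not from this PR
    def forward(self, x, grid):
        # lowers to torch.ops.aten.grid_sampler in the exported graph
        return F.grid_sample(
            x, grid, mode="nearest", padding_mode="zeros", align_corners=True
        )


x = torch.randn(1, 1, 5, 5).cuda()
grid = (torch.rand(1, 5, 2, 2).cuda() * 2) - 1  # sampling locations in [-1, 1]
trt_mod = torch_tensorrt.compile(Warp().eval().cuda(), ir="dynamo", inputs=[x, grid])
print(trt_mod(x, grid).shape)  # torch.Size([1, 1, 5, 2])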
38 changes: 33 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -330,6 +330,34 @@ def aten_ops_fmod(
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        1: (TRTTensor,),
    }
)
def aten_ops_grid(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.grid.grid(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        grid=args[1],
        interpolation_mode=args[2],
        padding_mode=args[3],
        align_corners=args[4],
    )


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
ctx: ConversionContext,
@@ -754,12 +782,12 @@ def aten_ops_cumsum(
)


-@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
-) # type: ignore[misc]
+)
def aten_ops_tile(
ctx: ConversionContext,
target: Target,
@@ -777,7 +805,7 @@ def aten_ops_tile(
)


-@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -1990,14 +2018,14 @@ def aten_ops_argmax(
)


-@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
-) # type: ignore[misc]
+)
def aten_ops_addmm(
ctx: ConversionContext,
target: Target,
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py

@@ -12,6 +12,7 @@
deconv,
elementwise,
embedding,
grid,
linear,
matmul,
normalization,
47 changes: 47 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py

@@ -0,0 +1,47 @@
from typing import Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

# nearest, linear, cubic
GridSamplerInterpolationMode = {
    0: trt.InterpolationMode.NEAREST,
    1: trt.InterpolationMode.LINEAR,
    2: trt.InterpolationMode.CUBIC,
}

# zeros, border, reflection
GridSamplerSampling = {
    0: trt.SampleMode.FILL,
    1: trt.SampleMode.CLAMP,
    2: trt.SampleMode.REFLECT,
}


def grid(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    grid: TRTTensor,
    interpolation_mode: int,
    padding_mode: int,
    align_corners: bool,
) -> TRTTensor:
    grid_layer = ctx.net.add_grid_sample(input, grid)
    assert interpolation_mode in GridSamplerInterpolationMode
    grid_layer.interpolation_mode = GridSamplerInterpolationMode[interpolation_mode]
    assert padding_mode in GridSamplerSampling
    grid_layer.sample_mode = GridSamplerSampling[padding_mode]
    grid_layer.align_corners = align_corners
    set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
    return grid_layer.get_output(0)
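Since the point of the PR is to expose IGridSampleLayer, here is a standalone TensorRT sketch of the same layer configuration that grid() performs through ctx.net.add_grid_sample(). It is illustrative only and not part of the diff; it assumes a TensorRT version whose Python API provides add_grid_sample, InterpolationMode, and SampleMode (TensorRT 8.5+), and the input shapes are arbitrary.

import tensorrt as trt

# Build a tiny network containing only an IGridSampleLayer (hypothetical sketch).
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 1, 5, 5))
grid = network.add_input("grid", trt.float32, (1, 5, 2, 2))

layer = network.add_grid_sample(x, grid)  # returns an IGridSampleLayer
layer.interpolation_mode = trt.InterpolationMode.LINEAR
layer.sample_mode = trt.SampleMode.FILL  # out-of-bounds samples filled with zeros
layer.align_corners = True
network.mark_output(layer.get_output(0))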
149 changes: 149 additions & 0 deletions tests/py/dynamo/conversion/test_grid_aten.py

@@ -0,0 +1,149 @@
import pytest
import torch
import torch.nn as nn
from .harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

grid_sampler_ops = [
    (
        "input_grid_interpolation_nearest_sample_fill",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_nearest_sample_clamp",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_nearest_sample_reflect",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_fill",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 1, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_clamp",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 1, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_reflect",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 1, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_fill",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 2, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_clamp",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 2, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_reflect",
        (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 2, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_nearest_sample_fill_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_nearest_sample_clamp_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_nearest_sample_reflect_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_fill_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 1, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_clamp_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 1, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_linear_sample_reflect_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 1, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_fill_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 2, 0, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_clamp_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 2, 1, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
    (
        "input_grid_interpolation_cubic_sample_reflect_2d",
        (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 2, 2, True)),
        [1, 1, 5, 5],
        [1, 5, 2, 2],
    ),
]


class TestGridConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                grid_sampler_op[0],
                grid_sampler_op[1],
                grid_sampler_op[2],
                grid_sampler_op[3],
            )
            for grid_sampler_op in grid_sampler_ops
        ]
    )
    def test_grid(self, _, op, input_shape, dim_shape):
        class TestModule(nn.Module):
            def __init__(self, grid_sampler_op):
                super().__init__()
                self.grid_sampler_op = grid_sampler_op

            def forward(self, x):
                grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
                return self.grid_sampler_op(x, grid)

        inputs = [torch.randn(input_shape, dtype=torch.float32)]
        grid_model = TestModule(op)
        self.run_test(grid_model, inputs)


if __name__ == "__main__":
run_tests()