Expose IGridSampleLayer
apbose committed Oct 12, 2023
1 parent 4e5b0f6 commit 09ffab2
Showing 5 changed files with 108 additions and 1 deletion.
18 changes: 17 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -245,7 +245,23 @@ def aten_ops_fmod(
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
def aten_ops_grid(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # args: input, grid, interpolation_mode, padding_mode, align_corners
    return impl.grid.grid(
        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4]
    )


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
ctx: ConversionContext,
target: Target,
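For orientation, the positional arguments forwarded as args[0]..args[4] above follow the schema aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners). A minimal eager-mode sketch of the op the new converter targets; shapes and mode codes here are illustrative, not taken from the commit:

import torch

# input: (N, C, H_in, W_in); grid: (N, H_out, W_out, 2) with values in [-1, 1]
x = torch.randn(1, 1, 5, 5)
grid = torch.rand(1, 5, 5, 2) * 2 - 1

# Positional order matches args[0]..args[4] in aten_ops_grid:
# input, grid, interpolation_mode, padding_mode, align_corners
out = torch.ops.aten.grid_sampler(x, grid, 0, 0, True)
print(out.shape)  # torch.Size([1, 1, 5, 5])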
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -23,6 +23,32 @@

_LOGGER: logging.Logger = logging.getLogger(__name__)

# Maps the integer interpolation code to trt.InterpolationMode: nearest, linear, cubic.
# NOTE: ATen's grid_sampler encodes 0=bilinear, 1=nearest, 2=bicubic, so this ordering
# should be double-checked against the aten convention.
# Assumes `import tensorrt as trt` is available in this module.
class GridSamplerInterpolation:
    def __init__(self):
        self.interpolator_mode = None

    def __call__(self, interpolator_int):
        if interpolator_int == 0:
            self.interpolator_mode = trt.InterpolationMode.NEAREST
        elif interpolator_int == 1:
            self.interpolator_mode = trt.InterpolationMode.LINEAR
        elif interpolator_int == 2:
            self.interpolator_mode = trt.InterpolationMode.CUBIC
        return self.interpolator_mode


# Maps the integer padding code to trt.SampleMode: zeros, border, reflection.
class GridSamplerPadding:
    def __init__(self):
        self.padding_mode = None

    def __call__(self, padding_int):
        # The TensorRT Python enum members are FILL/CLAMP/REFLECT
        # (the k-prefixed names belong to the C++ API).
        if padding_int == 0:
            self.padding_mode = trt.SampleMode.FILL
        elif padding_int == 1:
            self.padding_mode = trt.SampleMode.CLAMP
        elif padding_int == 2:
            self.padding_mode = trt.SampleMode.REFLECT
        return self.padding_mode

def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
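A small usage sketch for the two helpers above; note that each is instantiated first and then called with the integer code. The import path assumes the classes land in converter_utils as shown in this diff:

from torch_tensorrt.dynamo.conversion.converter_utils import (
    GridSamplerInterpolation,
    GridSamplerPadding,
)

interp = GridSamplerInterpolation()(1)  # trt.InterpolationMode.LINEAR
pad = GridSamplerPadding()(2)           # trt.SampleMode.REFLECT
print(interp, pad)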
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -11,6 +11,7 @@
deconv,
elementwise,
embedding,
grid,
linear,
matmul,
normalization,
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -0,0 +1,26 @@
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    GridSamplerInterpolation,
    GridSamplerPadding,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def grid(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    grid: TRTTensor,
    interpolation_mode: int,
    padding_mode: int,
    align_corners: bool,
) -> TRTTensor:
    # ConversionContext carries the TensorRT network being built (ctx.net),
    # matching how the converter in aten_ops_converters.py calls this function.
    grid_layer = ctx.net.add_grid_sample(input, grid)
    # The helper classes are instantiated and then called with the integer code.
    grid_layer.interpolation_mode = GridSamplerInterpolation()(interpolation_mode)
    # IGridSampleLayer exposes the padding behaviour via its sample_mode attribute.
    grid_layer.sample_mode = GridSamplerPadding()(padding_mode)
    grid_layer.align_corners = align_corners
    set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
    return grid_layer.get_output(0)
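Since the point of the commit is to expose TensorRT's IGridSampleLayer, here is a minimal standalone sketch of the underlying API that impl/grid.py wraps. It assumes TensorRT >= 8.5, where INetworkDefinition.add_grid_sample is available; tensor names and shapes are illustrative only:

import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 1, 5, 5))
g = network.add_input("grid", trt.float32, (1, 5, 5, 2))

layer = network.add_grid_sample(x, g)            # returns an IGridSampleLayer
layer.interpolation_mode = trt.InterpolationMode.LINEAR
layer.sample_mode = trt.SampleMode.FILL          # padding behaviour
layer.align_corners = True
network.mark_output(layer.get_output(0))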
38 changes: 38 additions & 0 deletions tests/py/dynamo/conversion/test_grid_aten.py
@@ -0,0 +1,38 @@
import pytest
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from .harness import DispatchTestCase

class TestGridConverter(DispatchTestCase):
    @parameterized.expand(
        [
            # (name, input_shape, dim_shape, interpolation_mode, padding_mode)
            ("input_grid_interpolation_nearest_sample_fill", [1, 1, 5, 5], [1, 5, 5, 2], 0, 0),
            ("input_grid_interpolation_nearest_sample_clamp", [1, 1, 5, 5], [1, 5, 5, 2], 0, 1),
            ("input_grid_interpolation_nearest_sample_reflect", [1, 1, 5, 5], [1, 5, 5, 2], 0, 2),
            ("input_grid_interpolation_linear_sample_fill", [1, 1, 5, 5], [1, 5, 5, 2], 1, 0),
            ("input_grid_interpolation_linear_sample_clamp", [1, 1, 5, 5], [1, 5, 5, 2], 1, 1),
            ("input_grid_interpolation_linear_sample_reflect", [1, 1, 5, 5], [1, 5, 5, 2], 1, 2),
            ("input_grid_interpolation_cubic_sample_fill", [1, 1, 5, 5], [1, 5, 5, 2], 2, 0),
            ("input_grid_interpolation_cubic_sample_clamp", [1, 1, 5, 5], [1, 5, 5, 2], 2, 1),
            ("input_grid_interpolation_cubic_sample_reflect", [1, 1, 5, 5], [1, 5, 5, 2], 2, 2),
        ]
    )
    def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
        class TestModule(nn.Module):
            def forward(self, x):
                # aten::grid_sampler expects a floating-point grid with values in
                # [-1, 1] of shape (N, H_out, W_out, 2) for a 4-D input.
                grid = torch.rand(dim_shape, dtype=torch.float32) * 2 - 1
                return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)

        inputs = [torch.randn(input_shape, dtype=torch.float32)]
        self.run_test(
            TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}
        )
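For an eager-mode reference against which these parameterizations can be sanity-checked, a sketch using torch.nn.functional.grid_sample (not part of the commit); the padding_mode strings "zeros"/"border"/"reflection" correspond to the 0/1/2 integer codes used above:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 5, 5)
grid = torch.rand(1, 5, 5, 2) * 2 - 1
ref = F.grid_sample(x, grid, mode="nearest", padding_mode="zeros", align_corners=True)
print(ref.shape)  # torch.Size([1, 1, 5, 5])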