Skip to content

Commit 09ffab2

Browse files
committed
Expose IGridSampleLayer
1 parent 4e5b0f6 commit 09ffab2

File tree

5 files changed

+108
-1
lines changed

5 files changed

+108
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,23 @@ def aten_ops_fmod(
245245
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
246246

247247

248-
@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc]
248+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
249+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
250+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
251+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
252+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
253+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
254+
def aten_ops_grid(
255+
ctx: ConversionContext,
256+
target: Target,
257+
args: Tuple[Argument, ...],
258+
kwargs: Dict[str, Argument],
259+
name: str,
260+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
261+
return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
262+
263+
264+
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
249265
def aten_ops_relu(
250266
ctx: ConversionContext,
251267
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,32 @@
2323

2424
_LOGGER: logging.Logger = logging.getLogger(__name__)
2525

26+
#nearesr, linear, cubc
27+
class GridSamplerInterpolation:
28+
def __init__(self):
29+
self.interpolator_mode = None
30+
def __call__(self, interpolator_int):
31+
if(interpolator_int == 0) :
32+
self.interpolator_mode = trt.InterpolationMode.NEAREST
33+
elif(interpolator_int == 1) :
34+
self.interpolator_mode = trt.InterpolationMode.LINEAR
35+
elif(interpolator_int == 2) :
36+
self.interpolator_mode = trt.InterpolationMode.CUBIC
37+
return self.interpolator_mode
38+
39+
40+
#zeros, border, reflection
41+
class GridSamplerPadding:
42+
def __init__(self):
43+
self.padding_mode = None
44+
def __call__(self, padding_int):
45+
if(padding_int == 0) :
46+
self.padding_mode = trt.SampleMode.kFILL
47+
elif(padding_int == 1) :
48+
self.padding_mode = trt.SampleMode.kCLAMP
49+
elif(padding_int == 2) :
50+
self.padding_mode = trt.SampleMode.kREFLECT
51+
return self.padding_mode
2652

2753
def get_node_name(node: torch.fx.Node) -> str:
2854
# nn_module_stack preserves the call stack of pytorch nn.modules

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
deconv,
1212
elementwise,
1313
embedding,
14+
grid,
1415
linear,
1516
matmul,
1617
normalization,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
7+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
8+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
9+
10+
def grid(
11+
network: TRTNetwork,
12+
target: Target,
13+
source_ir: Optional[SourceIR],
14+
name: str,
15+
input: TRTTensor,
16+
grid: TRTTensor,
17+
interpolation_mode: int,
18+
padding_mode: int,
19+
align_corners: bool,
20+
) -> TRTTensor:
21+
grid_layer = network.add_grid_sample(input, grid)
22+
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
23+
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
24+
grid_layer.align_corners = align_corners
25+
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
26+
return grid_layer.get_output(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
from parameterized import parameterized
7+
from .harness import DispatchTestCase
8+
9+
class TestGridConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
13+
("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
14+
("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
15+
("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
16+
("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
17+
("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
18+
("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
19+
("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
20+
("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
21+
]
22+
)
23+
def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
24+
class TestModule(nn.Module):
25+
def forward(self, x):
26+
input = torch.randn(10).reshape(input_shape)
27+
grid = torch.randint(-1, 1, dim_shape)
28+
return nn.functional.grid(input, grid, interpolation, sample)
29+
30+
inputs = [torch.randn(1, 10)]
31+
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
32+
33+
34+
35+
36+
37+
38+

0 commit comments

Comments
 (0)