Skip to content

Commit

Permalink
Grid test changes
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Oct 13, 2023
1 parent 09ffab2 commit 150c643
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 49 deletions.
23 changes: 20 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,37 @@ def aten_ops_fmod(
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_grid(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
return impl.grid.grid(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
grid=args[1],
interpolation_mode=args[2],
padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),

)


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
Expand Down
38 changes: 21 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,36 @@

_LOGGER: logging.Logger = logging.getLogger(__name__)

#nearesr, linear, cubc

# nearest, linear, cubic
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
def __call__(self, interpolator_int):
if(interpolator_int == 0) :

def __call__(self, interpolator_int):
if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
elif(interpolator_int == 1) :
elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
elif(interpolator_int == 2) :
elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode


#zeros, border, reflection
class GridSamplerPadding:

# zeros, border, reflection
class GridSamplerSampling:
def __init__(self):
self.padding_mode = None
def __call__(self, padding_int):
if(padding_int == 0) :
self.padding_mode = trt.SampleMode.kFILL
elif(padding_int == 1) :
self.padding_mode = trt.SampleMode.kCLAMP
elif(padding_int == 2) :
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
self.sample_mode = None

def __call__(self, sample_int):
if sample_int == 0:
self.sample_mode = trt.SampleMode.FILL
elif sample_int == 1:
self.sample_mode = trt.SampleMode.CLAMP
elif sample_int == 2:
self.sample_mode = trt.SampleMode.REFLECT
return self.sample_mode


def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
Expand Down
32 changes: 25 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from typing import Optional
from typing import Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
GridSamplerInterpolation,
GridSamplerSampling,
cast_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def grid(
network: TRTNetwork,
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
Expand All @@ -17,10 +24,21 @@ def grid(
interpolation_mode: int,
padding_mode: int,
align_corners: bool,
output_mask: Optional[Sequence[bool]] = None,
) -> TRTTensor:
grid_layer = network.add_grid_sample(input, grid)
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
grid_layer = ctx.net.add_grid_sample(input, grid)
interpolation_mode_trt = GridSamplerInterpolation()
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
sample_mode_trt = GridSamplerSampling()
grid_layer.sample_mode = sample_mode_trt(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
return grid_layer.get_output(0)
if output_mask is None:
return grid_layer.get_output(0)
else:
if output_mask[0] and output_mask[1]:
return (grid_layer.get_output(0), None)
elif output_mask[0]:
return grid_layer.get_output(0)
else:
return None
95 changes: 73 additions & 22 deletions tests/py/dynamo/conversion/test_grid_aten.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,89 @@
import pytest
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from .harness import DispatchTestCase


class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
(
"input_grid_interpolation_nearest_sample_fill",
[1, 1, 5, 5],
[1, 5, 2, 2],
0,
0,
),
(
"input_grid_interpolation_nearest_sample_clamp",
[1, 1, 5, 5],
[1, 5, 2, 2],
0,
1,
),
(
"input_grid_interpolation_nearest_sample_reflect",
[1, 1, 5, 5],
[1, 5, 2, 2],
0,
2,
),
(
"input_grid_interpolation_linear_sample_fill",
[1, 1, 5, 5],
[1, 5, 2, 2],
1,
0,
),
(
"input_grid_interpolation_linear_sample_clamp",
[1, 1, 5, 5],
[1, 5, 2, 2],
1,
1,
),
(
"input_grid_interpolation_linear_sample_reflect",
[1, 1, 5, 5],
[1, 5, 2, 2],
1,
2,
),
(
"input_grid_interpolation_cubic_sample_fill",
[1, 1, 5, 5],
[1, 5, 2, 2],
2,
0,
),
(
"input_grid_interpolation_cubic_sample_clamp",
[1, 1, 5, 5],
[1, 5, 2, 2],
2,
1,
),
(
"input_grid_interpolation_cubic_sample_reflect",
[1, 1, 5, 5],
[1, 5, 2, 2],
2,
2,
),
]
)
def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
def forward(self, x):
input = torch.randn(10).reshape(input_shape)
grid = torch.randint(-1, 1, dim_shape)
return nn.functional.grid(input, grid, interpolation, sample)

inputs = [torch.randn(1, 10)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})


grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)

inputs = [torch.randn(input_shape, dtype=torch.float32)]
self.run_test(TestModule(), inputs)




if __name__ == "__main__":
run_tests()

0 comments on commit 150c643

Please sign in to comment.