-
Notifications
You must be signed in to change notification settings - Fork 351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expose IGridSampleLayer #2290
Expose IGridSampleLayer #2290
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-09-05 22:14:18.899998+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-09-05 22:16:39.641414+00:00
@@ -16,36 +16,40 @@
from .._SourceIR import SourceIR
from .converter_registry import ConverterRegistry
_LOGGER: logging.Logger = logging.getLogger(__name__)
-#nearesr, linear, cubc
+
+# nearesr, linear, cubc
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
- def __call__(self, interpolator_int):
- if(interpolator_int == 0) :
+
+ def __call__(self, interpolator_int):
+ if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
- elif(interpolator_int == 1) :
+ elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
- elif(interpolator_int == 2) :
+ elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode
-
-
-#zeros, border, reflection
+
+
+# zeros, border, reflection
class GridSamplerPadding:
def __init__(self):
self.padding_mode = None
- def __call__(self, padding_int):
- if(padding_int == 0) :
+
+ def __call__(self, padding_int):
+ if padding_int == 0:
self.padding_mode = trt.SampleMode.kFILL
- elif(padding_int == 1) :
+ elif padding_int == 1:
self.padding_mode = trt.SampleMode.kCLAMP
- elif(padding_int == 2) :
+ elif padding_int == 2:
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
+
def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
# The call stack contains a detailed name of the module
# which shows exactly where the module is located in the
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-09-05 22:14:18.903998+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-09-05 22:16:39.839805+00:00
@@ -1,13 +1,17 @@
from typing import Optional
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+ GridSamplerInterpolation,
+ GridSamplerPadding,
+)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
def grid(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
@@ -21,6 +25,6 @@
grid_layer = network.add_grid_sample(input, grid)
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
- return grid_layer.get_output(0)
\ No newline at end of file
+ return grid_layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-05 22:14:18.899998+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-05 22:16:39.911918+00:00
@@ -163,11 +163,21 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return impl.grid.grid(network, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
+ return impl.grid.grid(
+ network,
+ target,
+ SourceIR.ATEN,
+ name,
+ args[0],
+ args[1],
+ args[2],
+ args[3],
+ args[4],
+ )
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
network: TRTNetwork,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-09-05 22:14:18.919998+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-09-05 22:16:43.370582+00:00
@@ -4,35 +4,31 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from .harness import DispatchTestCase
+
class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
- ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
- ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
- ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
- ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
- ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
- ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
- ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
- ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
- ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
+ ("input_grid_interpolation_nearest_sample_fill", [5, 5], [5, 2], 0, 0),
+ ("input_grid_interpolation_nearest_sample_clamp", [5, 5], [5, 2], 0, 1),
+ ("input_grid_interpolation_nearest_sample_reflect", [5, 5], [5, 2], 0, 2),
+ ("input_grid_interpolation_linear_sample_fill", [5, 5], [5, 2], 1, 0),
+ ("input_grid_interpolation_linear_sample_clamp", [5, 5], [5, 2], 1, 1),
+ ("input_grid_interpolation_linear_sample_reflect", [5, 5], [5, 2], 1, 2),
+ ("input_grid_interpolation_cubic_sample_fill", [5, 5], [5, 2], 2, 0),
+ ("input_grid_interpolation_cubic_sample_clamp", [5, 5], [5, 2], 2, 1),
+ ("input_grid_interpolation_cubic_sample_reflect", [5, 5], [5, 2], 2, 2),
]
)
- def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
+ def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
def forward(self, x):
input = torch.randn(10).reshape(input_shape)
grid = torch.randint(-1, 1, dim_shape)
return nn.functional.grid(input, grid, interpolation, sample)
inputs = [torch.randn(1, 10)]
- self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
-
-
-
-
-
-
-
\ No newline at end of file
+ self.run_test(
+ TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}
+ )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
40528e6
to
8290f63
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-06 17:56:01.641986+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-06 17:59:18.326542+00:00
@@ -21,36 +21,40 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
_LOGGER: logging.Logger = logging.getLogger(__name__)
-#nearesr, linear, cubc
+
+# nearesr, linear, cubc
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
- def __call__(self, interpolator_int):
- if(interpolator_int == 0) :
+
+ def __call__(self, interpolator_int):
+ if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
- elif(interpolator_int == 1) :
+ elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
- elif(interpolator_int == 2) :
+ elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode
-
-
-#zeros, border, reflection
+
+
+# zeros, border, reflection
class GridSamplerPadding:
def __init__(self):
self.padding_mode = None
- def __call__(self, padding_int):
- if(padding_int == 0) :
+
+ def __call__(self, padding_int):
+ if padding_int == 0:
self.padding_mode = trt.SampleMode.kFILL
- elif(padding_int == 1) :
+ elif padding_int == 1:
self.padding_mode = trt.SampleMode.kCLAMP
- elif(padding_int == 2) :
+ elif padding_int == 2:
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
+
def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
# The call stack contains a detailed name of the module
# which shows exactly where the module is located in the
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-06 17:56:01.641986+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-06 17:59:18.704188+00:00
@@ -1,13 +1,17 @@
from typing import Optional
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+ GridSamplerInterpolation,
+ GridSamplerPadding,
+)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
def grid(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
@@ -21,6 +25,6 @@
grid_layer = network.add_grid_sample(input, grid)
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
- return grid_layer.get_output(0)
\ No newline at end of file
+ return grid_layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-06 17:56:01.641986+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-06 17:59:18.888738+00:00
@@ -132,11 +132,13 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
+ return impl.grid.grid(
+ ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4]
+ )
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
ctx: ConversionContext,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-06 17:56:01.669986+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-06 17:59:24.740152+00:00
@@ -4,35 +4,31 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from .harness import DispatchTestCase
+
class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
- ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
- ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
- ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
- ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
- ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
- ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
- ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
- ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
- ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
+ ("input_grid_interpolation_nearest_sample_fill", [5, 5], [5, 2], 0, 0),
+ ("input_grid_interpolation_nearest_sample_clamp", [5, 5], [5, 2], 0, 1),
+ ("input_grid_interpolation_nearest_sample_reflect", [5, 5], [5, 2], 0, 2),
+ ("input_grid_interpolation_linear_sample_fill", [5, 5], [5, 2], 1, 0),
+ ("input_grid_interpolation_linear_sample_clamp", [5, 5], [5, 2], 1, 1),
+ ("input_grid_interpolation_linear_sample_reflect", [5, 5], [5, 2], 1, 2),
+ ("input_grid_interpolation_cubic_sample_fill", [5, 5], [5, 2], 2, 0),
+ ("input_grid_interpolation_cubic_sample_clamp", [5, 5], [5, 2], 2, 1),
+ ("input_grid_interpolation_cubic_sample_reflect", [5, 5], [5, 2], 2, 2),
]
)
- def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
+ def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
def forward(self, x):
input = torch.randn(10).reshape(input_shape)
grid = torch.randint(-1, 1, dim_shape)
return nn.functional.grid(input, grid, interpolation, sample)
inputs = [torch.randn(1, 10)]
- self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
-
-
-
-
-
-
-
\ No newline at end of file
+ self.run_test(
+ TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}
+ )
8290f63
to
09ffab2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-12 20:23:11.175226+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-12 20:26:13.990588+00:00
@@ -21,36 +21,40 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
_LOGGER: logging.Logger = logging.getLogger(__name__)
-#nearesr, linear, cubc
+
+# nearesr, linear, cubc
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
- def __call__(self, interpolator_int):
- if(interpolator_int == 0) :
+
+ def __call__(self, interpolator_int):
+ if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
- elif(interpolator_int == 1) :
+ elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
- elif(interpolator_int == 2) :
+ elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode
-
-
-#zeros, border, reflection
+
+
+# zeros, border, reflection
class GridSamplerPadding:
def __init__(self):
self.padding_mode = None
- def __call__(self, padding_int):
- if(padding_int == 0) :
+
+ def __call__(self, padding_int):
+ if padding_int == 0:
self.padding_mode = trt.SampleMode.kFILL
- elif(padding_int == 1) :
+ elif padding_int == 1:
self.padding_mode = trt.SampleMode.kCLAMP
- elif(padding_int == 2) :
+ elif padding_int == 2:
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
+
def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
# The call stack contains a detailed name of the module
# which shows exactly where the module is located in the
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-12 20:23:11.175226+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-12 20:26:14.342525+00:00
@@ -1,13 +1,17 @@
from typing import Optional
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+ GridSamplerInterpolation,
+ GridSamplerPadding,
+)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
def grid(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
@@ -21,6 +25,6 @@
grid_layer = network.add_grid_sample(input, grid)
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
- return grid_layer.get_output(0)
\ No newline at end of file
+ return grid_layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-12 20:23:11.175226+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-12 20:26:14.706447+00:00
@@ -256,11 +256,13 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
+ return impl.grid.grid(
+ ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4]
+ )
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
ctx: ConversionContext,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-12 20:23:11.199228+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-12 20:26:19.877166+00:00
@@ -4,35 +4,31 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from .harness import DispatchTestCase
+
class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
- ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
- ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
- ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
- ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
- ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
- ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
- ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
- ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
- ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
+ ("input_grid_interpolation_nearest_sample_fill", [5, 5], [5, 2], 0, 0),
+ ("input_grid_interpolation_nearest_sample_clamp", [5, 5], [5, 2], 0, 1),
+ ("input_grid_interpolation_nearest_sample_reflect", [5, 5], [5, 2], 0, 2),
+ ("input_grid_interpolation_linear_sample_fill", [5, 5], [5, 2], 1, 0),
+ ("input_grid_interpolation_linear_sample_clamp", [5, 5], [5, 2], 1, 1),
+ ("input_grid_interpolation_linear_sample_reflect", [5, 5], [5, 2], 1, 2),
+ ("input_grid_interpolation_cubic_sample_fill", [5, 5], [5, 2], 2, 0),
+ ("input_grid_interpolation_cubic_sample_clamp", [5, 5], [5, 2], 2, 1),
+ ("input_grid_interpolation_cubic_sample_reflect", [5, 5], [5, 2], 2, 2),
]
)
- def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
+ def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
def forward(self, x):
input = torch.randn(10).reshape(input_shape)
grid = torch.randint(-1, 1, dim_shape)
return nn.functional.grid(input, grid, interpolation, sample)
inputs = [torch.randn(1, 10)]
- self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
-
-
-
-
-
-
-
\ No newline at end of file
+ self.run_test(
+ TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}
+ )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 00:17:54.909933+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 00:20:33.526856+00:00
@@ -21,36 +21,40 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
_LOGGER: logging.Logger = logging.getLogger(__name__)
-#nearesr, linear, cubc
+
+# nearesr, linear, cubc
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
- def __call__(self, interpolator_int):
- if(interpolator_int == 0) :
+
+ def __call__(self, interpolator_int):
+ if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
- elif(interpolator_int == 1) :
+ elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
- elif(interpolator_int == 2) :
+ elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode
-
-
-#zeros, border, reflection
+
+
+# zeros, border, reflection
class GridSamplerPadding:
def __init__(self):
self.padding_mode = None
- def __call__(self, padding_int):
- if(padding_int == 0) :
+
+ def __call__(self, padding_int):
+ if padding_int == 0:
self.padding_mode = trt.SampleMode.kFILL
- elif(padding_int == 1) :
+ elif padding_int == 1:
self.padding_mode = trt.SampleMode.kCLAMP
- elif(padding_int == 2) :
+ elif padding_int == 2:
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
+
def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
# The call stack contains a detailed name of the module
# which shows exactly where the module is located in the
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-13 00:17:54.909933+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-13 00:20:33.917971+00:00
@@ -3,13 +3,18 @@
import torch
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor, GridSamplerInterpolation, GridSamplerSampling
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+ cast_trt_tensor,
+ GridSamplerInterpolation,
+ GridSamplerSampling,
+)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
def grid(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
@@ -20,20 +25,20 @@
padding_mode: int,
align_corners: bool,
output_mask: Optional[Sequence[bool]] = None,
) -> TRTTensor:
grid_layer = ctx.net.add_grid_sample(input, grid)
- interpolation_mode_trt = GridSamplerInterpolation()
+ interpolation_mode_trt = GridSamplerInterpolation()
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
sample_mode_trt = GridSamplerSampling()
- grid_layer.sample_mode = sample_mode_trt(padding_mode)
+ grid_layer.sample_mode = sample_mode_trt(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
- if(output_mask is None):
+ if output_mask is None:
return grid_layer.get_output(0)
else:
- if(output_mask[0] and output_mask[1]):
+ if output_mask[0] and output_mask[1]:
return (grid_layer.get_output(0), None)
- elif(output_mask[0]):
+ elif output_mask[0]:
return grid_layer.get_output(0)
else:
- return None
\ No newline at end of file
+ return None
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:17:54.909933+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:20:34.077902+00:00
@@ -262,21 +262,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(
- ctx,
- target,
- SourceIR.ATEN,
- name,
- input=args[0],
- grid=args[1],
- interpolation_mode=args[2],
- padding_mode=args[3],
+ ctx,
+ target,
+ SourceIR.ATEN,
+ name,
+ input=args[0],
+ grid=args[1],
+ interpolation_mode=args[2],
+ padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),
-
)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-13 00:17:54.929933+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-13 00:20:38.093945+00:00
@@ -4,35 +4,86 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from harness import DispatchTestCase
+
class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
- ("input_grid_interpolation_nearest_sample_fill", [1,1,5,5], [1,5,2,2], 0, 0),
- ("input_grid_interpolation_nearest_sample_clamp", [1,1,5,5], [1,5,2,2], 0, 1),
- ("input_grid_interpolation_nearest_sample_reflect", [1,1,5,5], [1,5,2,2], 0, 2),
- ("input_grid_interpolation_linear_sample_fill", [1,1,5,5], [1,5,2,2], 1, 0),
- ("input_grid_interpolation_linear_sample_clamp", [1,1,5,5], [1,5,2,2], 1, 1),
- ("input_grid_interpolation_linear_sample_reflect", [1,1,5,5], [1,5,2,2], 1, 2),
- ("input_grid_interpolation_cubic_sample_fill", [1,1,5,5], [1,5,2,2], 2, 0),
- ("input_grid_interpolation_cubic_sample_clamp", [1,1,5,5], [1,5,2,2], 2, 1),
- ("input_grid_interpolation_cubic_sample_reflect", [1,1,5,5], [1,5,2,2], 2, 2),
+ (
+ "input_grid_interpolation_nearest_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_nearest_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_nearest_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 2,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 2,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 2,
+ ),
]
)
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
- def forward(self, x):
+ def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
- inputs = [torch.randn(input_shape, dtype = torch.float32)]
+
+ inputs = [torch.randn(input_shape, dtype=torch.float32)]
self.run_test(TestModule(), inputs)
+
if __name__ == "__main__":
run_tests()
-
-
-
-
-
-
\ No newline at end of file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 00:22:45.467171+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 00:25:24.662061+00:00
@@ -21,36 +21,40 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
_LOGGER: logging.Logger = logging.getLogger(__name__)
-#nearesr, linear, cubc
+
+# nearesr, linear, cubc
class GridSamplerInterpolation:
def __init__(self):
self.interpolator_mode = None
- def __call__(self, interpolator_int):
- if(interpolator_int == 0) :
+
+ def __call__(self, interpolator_int):
+ if interpolator_int == 0:
self.interpolator_mode = trt.InterpolationMode.NEAREST
- elif(interpolator_int == 1) :
+ elif interpolator_int == 1:
self.interpolator_mode = trt.InterpolationMode.LINEAR
- elif(interpolator_int == 2) :
+ elif interpolator_int == 2:
self.interpolator_mode = trt.InterpolationMode.CUBIC
return self.interpolator_mode
-
-
-#zeros, border, reflection
+
+
+# zeros, border, reflection
class GridSamplerPadding:
def __init__(self):
self.padding_mode = None
- def __call__(self, padding_int):
- if(padding_int == 0) :
+
+ def __call__(self, padding_int):
+ if padding_int == 0:
self.padding_mode = trt.SampleMode.kFILL
- elif(padding_int == 1) :
+ elif padding_int == 1:
self.padding_mode = trt.SampleMode.kCLAMP
- elif(padding_int == 2) :
+ elif padding_int == 2:
self.padding_mode = trt.SampleMode.kREFLECT
return self.padding_mode
+
def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
# The call stack contains a detailed name of the module
# which shows exactly where the module is located in the
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-13 00:22:45.467171+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/grid.py 2023-10-13 00:25:25.012634+00:00
@@ -3,13 +3,18 @@
import torch
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor, GridSamplerInterpolation, GridSamplerSampling
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+ cast_trt_tensor,
+ GridSamplerInterpolation,
+ GridSamplerSampling,
+)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
def grid(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
@@ -20,20 +25,20 @@
padding_mode: int,
align_corners: bool,
output_mask: Optional[Sequence[bool]] = None,
) -> TRTTensor:
grid_layer = ctx.net.add_grid_sample(input, grid)
- interpolation_mode_trt = GridSamplerInterpolation()
+ interpolation_mode_trt = GridSamplerInterpolation()
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
sample_mode_trt = GridSamplerSampling()
- grid_layer.sample_mode = sample_mode_trt(padding_mode)
+ grid_layer.sample_mode = sample_mode_trt(padding_mode)
grid_layer.align_corners = align_corners
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
- if(output_mask is None):
+ if output_mask is None:
return grid_layer.get_output(0)
else:
- if(output_mask[0] and output_mask[1]):
+ if output_mask[0] and output_mask[1]:
return (grid_layer.get_output(0), None)
- elif(output_mask[0]):
+ elif output_mask[0]:
return grid_layer.get_output(0)
else:
- return None
\ No newline at end of file
+ return None
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:22:45.467171+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:25:25.209460+00:00
@@ -262,21 +262,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(
- ctx,
- target,
- SourceIR.ATEN,
- name,
- input=args[0],
- grid=args[1],
- interpolation_mode=args[2],
- padding_mode=args[3],
+ ctx,
+ target,
+ SourceIR.ATEN,
+ name,
+ input=args[0],
+ grid=args[1],
+ interpolation_mode=args[2],
+ padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),
-
)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-13 00:22:45.487171+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_grid_aten.py 2023-10-13 00:25:28.842495+00:00
@@ -4,35 +4,86 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from parameterized import parameterized
from harness import DispatchTestCase
+
class TestGridConverter(DispatchTestCase):
@parameterized.expand(
[
- ("input_grid_interpolation_nearest_sample_fill", [1,1,5,5], [1,5,2,2], 0, 0),
- ("input_grid_interpolation_nearest_sample_clamp", [1,1,5,5], [1,5,2,2], 0, 1),
- ("input_grid_interpolation_nearest_sample_reflect", [1,1,5,5], [1,5,2,2], 0, 2),
- ("input_grid_interpolation_linear_sample_fill", [1,1,5,5], [1,5,2,2], 1, 0),
- ("input_grid_interpolation_linear_sample_clamp", [1,1,5,5], [1,5,2,2], 1, 1),
- ("input_grid_interpolation_linear_sample_reflect", [1,1,5,5], [1,5,2,2], 1, 2),
- ("input_grid_interpolation_cubic_sample_fill", [1,1,5,5], [1,5,2,2], 2, 0),
- ("input_grid_interpolation_cubic_sample_clamp", [1,1,5,5], [1,5,2,2], 2, 1),
- ("input_grid_interpolation_cubic_sample_reflect", [1,1,5,5], [1,5,2,2], 2, 2),
+ (
+ "input_grid_interpolation_nearest_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_nearest_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_nearest_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 0,
+ 2,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_linear_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 1,
+ 2,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_fill",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 0,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_clamp",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 1,
+ ),
+ (
+ "input_grid_interpolation_cubic_sample_reflect",
+ [1, 1, 5, 5],
+ [1, 5, 2, 2],
+ 2,
+ 2,
+ ),
]
)
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
class TestModule(nn.Module):
- def forward(self, x):
+ def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
- inputs = [torch.randn(input_shape, dtype = torch.float32)]
+
+ inputs = [torch.randn(input_shape, dtype=torch.float32)]
self.run_test(TestModule(), inputs)
+
if __name__ == "__main__":
run_tests()
-
-
-
-
-
-
\ No newline at end of file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
73f1158
to
150c643
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:25:52.876877+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:28:28.778380+00:00
@@ -262,21 +262,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(
- ctx,
- target,
- SourceIR.ATEN,
- name,
- input=args[0],
- grid=args[1],
- interpolation_mode=args[2],
- padding_mode=args[3],
+ ctx,
+ target,
+ SourceIR.ATEN,
+ name,
+ input=args[0],
+ grid=args[1],
+ interpolation_mode=args[2],
+ padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),
-
)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
150c643
to
ac7c95c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:51:36.892575+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-13 00:54:31.161981+00:00
@@ -262,21 +262,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(
- ctx,
- target,
- SourceIR.ATEN,
- name,
- input=args[0],
- grid=args[1],
- interpolation_mode=args[2],
- padding_mode=args[3],
+ ctx,
+ target,
+ SourceIR.ATEN,
+ name,
+ input=args[0],
+ grid=args[1],
+ interpolation_mode=args[2],
+ padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),
-
)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) | ||
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out) | ||
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out) | ||
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out) | ||
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add # type: ignore[misc]
to these
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove _backward
implementations. Consider changing the decorator stack to:
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.default) # type: ignore[misc]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add the # type: ignore[misc]
. I read online that this indicates that mypy
would ignore this. So how do we know where this is to be mentioned?
Also regarding the .out
, I see that the target in the test is torch.ops.aten.grid_sampler
, so should default be mentioned then? Because just mentioning
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.default) # type: ignore[misc]
Shows the following error
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 335,
in call_function
raise UnsupportedOperatorException(
torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.PyCapsule.grid_sampler not
currently supported!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the # type: ignore[misc]
, generally mypy
linting will indicate an error for untyped decorators. Our dynamo_tensorrt_converter
is not untyped, but for some reason the error persists. For now, we just keep the ignore
on all decorators in this file.
On the second issue, I think you can remove the .out
variant, but change the @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default)
back to just @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
. That should address the error
interpolation_mode=args[2], | ||
padding_mode=args[3], | ||
align_corners=args_bounds_check(args, 4, True), | ||
output_mask=args_bounds_check(args, 5, None), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be removed once the out
variants are removed from the decorator list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean the output_mask
argument? As far as I understood, the .backward
would have the output_mask
argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I meant the output_mask
argument. I don't think we should ever encounter an argument of the backward
variety since the backend is strictly forward/inference-only
# nearest, linear, cubic | ||
class GridSamplerInterpolation: | ||
def __init__(self): | ||
self.interpolator_mode = None | ||
|
||
def __call__(self, interpolator_int): | ||
if interpolator_int == 0: | ||
self.interpolator_mode = trt.InterpolationMode.NEAREST | ||
elif interpolator_int == 1: | ||
self.interpolator_mode = trt.InterpolationMode.LINEAR | ||
elif interpolator_int == 2: | ||
self.interpolator_mode = trt.InterpolationMode.CUBIC | ||
return self.interpolator_mode | ||
|
||
|
||
# zeros, border, reflection | ||
class GridSamplerSampling: | ||
def __init__(self): | ||
self.sample_mode = None | ||
|
||
def __call__(self, sample_int): | ||
if sample_int == 0: | ||
self.sample_mode = trt.SampleMode.FILL | ||
elif sample_int == 1: | ||
self.sample_mode = trt.SampleMode.CLAMP | ||
elif sample_int == 2: | ||
self.sample_mode = trt.SampleMode.REFLECT | ||
return self.sample_mode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could just go in the grid.py
file, since it is very specific to the grid operations.
grid_layer = ctx.net.add_grid_sample(input, grid) | ||
interpolation_mode_trt = GridSamplerInterpolation() | ||
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode) | ||
sample_mode_trt = GridSamplerSampling() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using a dictionary for this - could be grid_layer.interpolation_mode = Dictionary.get(interpolation_mode, None)
if output_mask[0] and output_mask[1]: | ||
return (grid_layer.get_output(0), None) | ||
elif output_mask[0]: | ||
return grid_layer.get_output(0) | ||
else: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bring up one level and make all of these elif
statements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-20 00:36:43.890962+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-10-20 00:39:35.534806+00:00
@@ -262,21 +262,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.grid.grid(
- ctx,
- target,
- SourceIR.ATEN,
- name,
- input=args[0],
- grid=args[1],
- interpolation_mode=args[2],
- padding_mode=args[3],
+ ctx,
+ target,
+ SourceIR.ATEN,
+ name,
+ input=args[0],
+ grid=args[1],
+ interpolation_mode=args[2],
+ padding_mode=args[3],
align_corners=args_bounds_check(args, 4, True),
output_mask=args_bounds_check(args, 5, None),
-
)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
908d1fe
to
783f7ef
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
Fixes #2202