Skip to content

Commit 783f7ef

Browse files
committed
Removing the misc and removing the grid_sampler.3d cases
1 parent b298710 commit 783f7ef

File tree

2 files changed

+4
-17
lines changed

2 files changed

+4
-17
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,14 @@ def aten_ops_fmod(
330330
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
331331

332332

333-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc]
334-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc]
335-
# commented this for now, see py/dynamo/conversion/tests/test_grid_aten. Should this be removed altogether?
336-
# @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc]
333+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
334+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
337335
@enforce_tensor_types(
338336
{
339337
0: (TRTTensor,),
340338
1: (TRTTensor,),
341339
}
342-
) # type: ignore[misc]
340+
)
343341
def aten_ops_grid(
344342
ctx: ConversionContext,
345343
target: Target,
@@ -360,7 +358,7 @@ def aten_ops_grid(
360358
)
361359

362360

363-
@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc]
361+
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
364362
def aten_ops_relu(
365363
ctx: ConversionContext,
366364
target: Target,

tests/py/dynamo/conversion/test_grid_aten.py

-11
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,6 @@
115115
[1, 1, 5, 5],
116116
[1, 5, 2, 2],
117117
),
118-
# The 3d cases with 4d input gives the error that it requires 5d input for both input and grid
119-
# The 5d input fails in the generation of the Grid Layer since the TensorRT layer requires 4d input
120-
# ("input_grid_interpolation_nearest_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
121-
# ("input_grid_interpolation_nearest_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
122-
# ("input_grid_interpolation_nearest_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
123-
# ("input_grid_interpolation_linear_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
124-
# ("input_grid_interpolation_linear_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
125-
# ("input_grid_interpolation_linear_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
126-
# ("input_grid_interpolation_cubic_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
127-
# ("input_grid_interpolation_cubic_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
128-
# ("input_grid_interpolation_cubic_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
129118
]
130119

131120

0 commit comments

Comments
 (0)