Skip to content

Commit

Permalink
Removing the misc and removing the grid_sampler.3d cases
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Nov 9, 2023
1 parent b298710 commit 783f7ef
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,16 +330,14 @@ def aten_ops_fmod(
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc]
# commented this for now, see py/dynamo/conversion/tests/test_grid_aten. Should this be removed altogether?
# @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (TRTTensor,),
}
) # type: ignore[misc]
)
def aten_ops_grid(
ctx: ConversionContext,
target: Target,
Expand All @@ -360,7 +358,7 @@ def aten_ops_grid(
)


@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
ctx: ConversionContext,
target: Target,
Expand Down
11 changes: 0 additions & 11 deletions tests/py/dynamo/conversion/test_grid_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,6 @@
[1, 1, 5, 5],
[1, 5, 2, 2],
),
# The 3d cases with 4d input gives the error that it requires 5d input for both input and grid
# The 5d input fails in the generation of the Grid Layer since the TensorRT layer requires 4d input
# ("input_grid_interpolation_nearest_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_nearest_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_nearest_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_linear_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_linear_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_linear_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_cubic_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_cubic_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
# ("input_grid_interpolation_cubic_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
]


Expand Down

0 comments on commit 783f7ef

Please sign in to comment.