1
+ import pytest
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.testing._internal.common_utils import run_tests
5
+ from torch_tensorrt import Input
6
+ from parameterized import parameterized
7
+ from .harness import DispatchTestCase
8
+
9
+ class TestGridConverter(DispatchTestCase):
10
+ @parameterized.expand(
11
+ [
12
+ ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
13
+ ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
14
+ ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
15
+ ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
16
+ ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
17
+ ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
18
+ ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
19
+ ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
20
+ ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
21
+ ]
22
+ )
23
+ def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
24
+ class TestModule(nn.Module):
25
+ def forward(self, x):
26
+ input = torch.randn(10).reshape(input_shape)
27
+ grid = torch.randint(-1, 1, dim_shape)
28
+ return nn.functional.grid(input, grid, interpolation, sample)
29
+
30
+ inputs = [torch.randn(1, 10)]
31
+ self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
32
+
33
+
34
+
35
+
36
+
37
+
38
+
0 commit comments