1
+ import pytest
2
+ import torch
3
+ import torch .nn as nn
4
+ from torch .testing ._internal .common_utils import run_tests
5
+ from torch_tensorrt import Input
6
+ from parameterized import parameterized
7
+ from .harness import DispatchTestCase
8
+
9
+ class TestGridConverter (DispatchTestCase ):
10
+ @parameterized .expand (
11
+ [
12
+ ("input_grid_interpolation_nearest_sample_fill" , [5 ,5 ], [5 ,2 ], 0 , 0 ),
13
+ ("input_grid_interpolation_nearest_sample_clamp" , [5 ,5 ], [5 ,2 ], 0 , 1 ),
14
+ ("input_grid_interpolation_nearest_sample_reflect" , [5 ,5 ], [5 ,2 ], 0 , 2 ),
15
+ ("input_grid_interpolation_linear_sample_fill" , [5 ,5 ], [5 ,2 ], 1 , 0 ),
16
+ ("input_grid_interpolation_linear_sample_clamp" , [5 ,5 ], [5 ,2 ], 1 , 1 ),
17
+ ("input_grid_interpolation_linear_sample_reflect" , [5 ,5 ], [5 ,2 ], 1 , 2 ),
18
+ ("input_grid_interpolation_cubic_sample_fill" , [5 ,5 ], [5 ,2 ], 2 , 0 ),
19
+ ("input_grid_interpolation_cubic_sample_clamp" , [5 ,5 ], [5 ,2 ], 2 , 1 ),
20
+ ("input_grid_interpolation_cubic_sample_reflect" , [5 ,5 ], [5 ,2 ], 2 , 2 ),
21
+ ]
22
+ )
23
+ def test_grid (self ,_ , input_shape , dim_shape , interpolation , sample ):
24
+ class TestModule (nn .Module ):
25
+ def forward (self , x ):
26
+ input = torch .randn (10 ).reshape (input_shape )
27
+ grid = torch .randint (- 1 , 1 , dim_shape )
28
+ return nn .functional .grid (input , grid , interpolation , sample )
29
+
30
+ inputs = [torch .randn (1 , 10 )]
31
+ self .run_test (TestModule (), inputs , expected_ops = {torch .ops .aten .grid_sampler .out })
32
+
33
+
34
+
35
+
36
+
37
+
38
+
0 commit comments