diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index bf7769c608..bb1130493e 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -1,28 +1,93 @@ import torch import torch.nn as nn +from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from .harness import DispatchTestCase -class TestIndexConverter(DispatchTestCase): - def test_index_zero_two_dim(self): - class TestModule(nn.Module): +class TestIndexConstantConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + "index_zero_two_dim_indices_input", + [None, torch.randint(0, 1, (1, 1))], + torch.randn(2, 2), + ), + ( + "index_zero_three_dim_indices_input", + [None, torch.randint(0, 1, (1, 1)), None], + torch.randn(2, 2, 2), + ), + ( + "index_zero_index_one_three_dim_indices_input", + [None, torch.randint(0, 1, (1, 1)), torch.randint(0, 1, (1, 1))], + torch.randn(2, 2, 2), + ), + ( + "index_zero_index_one_four_dim_indices_input", + [None, torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1]), None], + torch.randn(2, 4, 4, 2), + ), + ( + "index_zero_index_one_four_dim_indices_input_SD", + [ + None, + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]), + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]), + None, + ], + torch.randn(2, 1280, 8, 8), + ), + ( + "index_zero_index_one_four_dim_indices_input_SD_unsqueeze", + [ + None, + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]) + .unsqueeze(0) + .T.long(), + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]) + .unsqueeze(0) + .T.long(), + None, + ], + torch.randn(2, 1280, 8, 8), + ), + ( + "index_zero_index_one_four_dim_indices_input_SD_unsqueeze_broadcast", + [ + None, + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]), + torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]) + .unsqueeze(0) + .T.long(), + None, + ], + torch.randn(2, 1280, 8, 8), + ), + ( + "index_zero_index_one_four_dim_indices_input_non_continuous", + [None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])], + torch.randn(2, 4, 4, 2), + ), + ] + ) + def test_index_constant(self, _, index, input): + class TestModule(torch.nn.Module): def __init__(self): - self.index0 = torch.randint(0, 1, (1, 1)) super().__init__() - def forward(self, x): - indices = [None, self.index0] - out = torch.ops.aten.index.Tensor(x, indices) - return out + def forward(self, input): + return torch.ops.aten.index.Tensor(input, index) + + inputs = [input] + self.run_test(TestModule(), inputs) - input = [torch.randn(2, 2)] - self.run_test( - TestModule(), - input, - ) +# The below tests cannot be included in the parameterized +# [None, index0] cannot be passed as torch.Tensor to DispatchTestCase.run_test() +# tensorrt.Input requires the input to be torch Tensor +class TestIndexConverter(DispatchTestCase): def test_index_zero_two_dim_ITensor(self): class TestModule(nn.Module): def forward(self, x, index0): @@ -38,23 +103,6 @@ def forward(self, x, index0): [input, index0], ) - def test_index_zero_index_three_dim(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.randint(0, 1, (1, 1)) - super().__init__() - - def forward(self, x): - indices = [None, self.index0, None] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 2, 2)] - self.run_test( - TestModule(), - input, - ) - def test_index_zero_index_three_dim_ITensor(self): class TestModule(nn.Module): def forward(self, x, index0): @@ -67,122 +115,6 @@ def forward(self, x, index0): index0 = index0.to(torch.int32) self.run_test(TestModule(), [input, index0]) - def test_index_zero_index_one_index_two_three_dim(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.randint(0, 1, (1, 1)) - self.index1 = torch.randint(0, 1, (1, 1)) - super().__init__() - - def forward(self, x): - indices = [None, self.index0, self.index1] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 2, 2)] - self.run_test( - TestModule(), - input, - ) - - def test_index_zero_index_one_four_dim(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.tensor([0, 0, 1, 1]) - self.index1 = torch.tensor([0, 0, 1, 1]) - super().__init__() - - def forward(self, x): - indices = [None, self.index0, self.index1, None] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 4, 4, 2)] - self.run_test( - TestModule(), - input, - ) - - def test_index_zero_index_one_four_dim_SD(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.tensor( - [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7] - ) - self.index1 = torch.tensor( - [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7] - ) - super().__init__() - - def forward(self, x): - indices = [None, self.index0, self.index1, None] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 1280, 8, 8)] - self.run_test( - TestModule(), - input, - ) - - def test_index_one_SD_unsqueeze_four_dim(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.tensor( - [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7] - ) - self.index1 = self.index0.unsqueeze(0).T.long() - super().__init__() - - def forward(self, x): - indices = [None, None, self.index1, self.index1] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 1280, 8, 8)] - self.run_test( - TestModule(), - input, - ) - - def test_index_zero_index_one_index_two_SD_unsqueeze_four_dim_broadcast(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.tensor( - [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7] - ) - self.index1 = self.index0.unsqueeze(0).T.long() - super().__init__() - - def forward(self, x): - indices = [None, None, self.index0, self.index1] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 1280, 8, 8)] - self.run_test( - TestModule(), - input, - ) - - def test_index_zero_index_one_index_four_dim_non_continuous(self): - class TestModule(nn.Module): - def __init__(self): - self.index0 = torch.tensor([0, 0, 1, 1]) - self.index1 = torch.tensor([0, 0, 1, 1]) - super().__init__() - - def forward(self, x): - indices = [None, self.index0, None, self.index1] - out = torch.ops.aten.index.Tensor(x, indices) - return out - - input = [torch.randn(2, 4, 4, 2)] - self.run_test( - TestModule(), - input, - ) - if __name__ == "__main__": run_tests()