Skip to content

Commit

Permalink
Dynamic shape index
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Jul 25, 2024
1 parent abf3370 commit 059d09a
Showing 1 changed file with 78 additions and 146 deletions.
224 changes: 78 additions & 146 deletions tests/py/dynamo/conversion/test_index_aten.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,93 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIndexConverter(DispatchTestCase):
def test_index_zero_two_dim(self):
class TestModule(nn.Module):
class TestIndexConstantConverter(DispatchTestCase):
@parameterized.expand(
[
(
"index_zero_two_dim_indices_input",
[None, torch.randint(0, 1, (1, 1))],
torch.randn(2, 2),
),
(
"index_zero_three_dim_indices_input",
[None, torch.randint(0, 1, (1, 1)), None],
torch.randn(2, 2, 2),
),
(
"index_zero_index_one_three_dim_indices_input",
[None, torch.randint(0, 1, (1, 1)), torch.randint(0, 1, (1, 1))],
torch.randn(2, 2, 2),
),
(
"index_zero_index_one_four_dim_indices_input",
[None, torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1]), None],
torch.randn(2, 4, 4, 2),
),
(
"index_zero_index_one_four_dim_indices_input_SD",
[
None,
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]),
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]),
None,
],
torch.randn(2, 1280, 8, 8),
),
(
"index_zero_index_one_four_dim_indices_input_SD_unsqueeze",
[
None,
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
.unsqueeze(0)
.T.long(),
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
.unsqueeze(0)
.T.long(),
None,
],
torch.randn(2, 1280, 8, 8),
),
(
"index_zero_index_one_four_dim_indices_input_SD_unsqueeze_broadcast",
[
None,
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]),
torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
.unsqueeze(0)
.T.long(),
None,
],
torch.randn(2, 1280, 8, 8),
),
(
"index_zero_index_one_four_dim_indices_input_non_continuous",
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
torch.randn(2, 4, 4, 2),
),
]
)
def test_index_constant(self, _, index, input):
class TestModule(torch.nn.Module):
def __init__(self):
self.index0 = torch.randint(0, 1, (1, 1))
super().__init__()

def forward(self, x):
indices = [None, self.index0]
out = torch.ops.aten.index.Tensor(x, indices)
return out
def forward(self, input):
return torch.ops.aten.index.Tensor(input, index)

inputs = [input]
self.run_test(TestModule(), inputs)

input = [torch.randn(2, 2)]
self.run_test(
TestModule(),
input,
)

# The below tests cannot be included in the parameterized
# [None, index0] cannot be passed as torch.Tensor to DispatchTestCase.run_test()
# tensorrt.Input requires the input to be torch Tensor
class TestIndexConverter(DispatchTestCase):
def test_index_zero_two_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
Expand All @@ -38,23 +103,6 @@ def forward(self, x, index0):
[input, index0],
)

def test_index_zero_index_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.randint(0, 1, (1, 1))
super().__init__()

def forward(self, x):
indices = [None, self.index0, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 2, 2)]
self.run_test(
TestModule(),
input,
)

def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
Expand All @@ -67,122 +115,6 @@ def forward(self, x, index0):
index0 = index0.to(torch.int32)
self.run_test(TestModule(), [input, index0])

def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.randint(0, 1, (1, 1))
self.index1 = torch.randint(0, 1, (1, 1))
super().__init__()

def forward(self, x):
indices = [None, self.index0, self.index1]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 2, 2)]
self.run_test(
TestModule(),
input,
)

def test_index_zero_index_one_four_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.tensor([0, 0, 1, 1])
self.index1 = torch.tensor([0, 0, 1, 1])
super().__init__()

def forward(self, x):
indices = [None, self.index0, self.index1, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 4, 4, 2)]
self.run_test(
TestModule(),
input,
)

def test_index_zero_index_one_four_dim_SD(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.tensor(
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
)
self.index1 = torch.tensor(
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
)
super().__init__()

def forward(self, x):
indices = [None, self.index0, self.index1, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 1280, 8, 8)]
self.run_test(
TestModule(),
input,
)

def test_index_one_SD_unsqueeze_four_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.tensor(
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
)
self.index1 = self.index0.unsqueeze(0).T.long()
super().__init__()

def forward(self, x):
indices = [None, None, self.index1, self.index1]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 1280, 8, 8)]
self.run_test(
TestModule(),
input,
)

def test_index_zero_index_one_index_two_SD_unsqueeze_four_dim_broadcast(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.tensor(
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
)
self.index1 = self.index0.unsqueeze(0).T.long()
super().__init__()

def forward(self, x):
indices = [None, None, self.index0, self.index1]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 1280, 8, 8)]
self.run_test(
TestModule(),
input,
)

def test_index_zero_index_one_index_four_dim_non_continuous(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.tensor([0, 0, 1, 1])
self.index1 = torch.tensor([0, 0, 1, 1])
super().__init__()

def forward(self, x):
indices = [None, self.index0, None, self.index1]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = [torch.randn(2, 4, 4, 2)]
self.run_test(
TestModule(),
input,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 059d09a

Please sign in to comment.