Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/gather import #1843

Merged
merged 11 commits into from
Jun 3, 2024
Prev Previous commit
Next Next commit
Create PyTorch script for gather
agelas committed May 31, 2024
commit f9205f826cd5295ab1353f7037624b15fdd0e1ce
20 changes: 10 additions & 10 deletions crates/burn-import/onnx-tests/tests/gather/gather.onnx
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
pytorch2.2.2:�
a
onnx::GatherElements_0
onnx::GatherElements_12/GatherElements"GatherElements*
pytorch2.1.1:�
A
onnx::Gather_0
onnx::Gather_12/Gather"Gather*
axis�
main_graphZ(
onnx::GatherElements_0
main_graphZ
onnx::Gather_0


Z(
onnx::GatherElements_1


Z
onnx::Gather_1


b
2

16 changes: 8 additions & 8 deletions crates/burn-import/onnx-tests/tests/gather/gather.py
Original file line number Diff line number Diff line change
@@ -11,8 +11,8 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x
gathered = torch.index_select(x, 1, index)
return gathered


def main():
@@ -24,19 +24,19 @@ def main():
model.eval()
device = torch.device("cpu")
onnx_name = "gather.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

dummy_input = torch.randn(2, 3, device=device)
dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])
test_input = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
test_index = torch.tensor([0, 2], dtype=torch.int64)

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)