Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Dot and BatchedDot in PyTensor #878

Merged
merged 13 commits into from
Jul 18, 2024
12 changes: 12 additions & 0 deletions pytensor/link/pytorch/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.math import Dot


@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
Copy link
Member

@ricardoV94 ricardoV94 Jul 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have to import this file from pytorch.dispatch.__init__ for it to be registered (the test is failing in the CI). But Dot is not defined in nlinalg, so we should put it in dispatch/match.py? Same for the test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I based it off the JAX link. If you take a look at pytensor/link/jax/dispatch/nlinalg.py you will see Max, Argmax, and Dot Ops from math in there. Do you want me to separate them out for JAX too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can also put the Argmax I am implementing in pytorch/dispatch/math.py.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah in general we want to keep it more or less mirrored with the file structure where they are defined. Although our tensor/basic.py and tensor/math.py are in need of being split of as they have way too many lines

def dot(x, y):
return torch.matmul(x, y)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

return dot
30 changes: 30 additions & 0 deletions tests/link/pytorch/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np

from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import matrix, scalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_tensor_basics():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.array([[6, 3], [3, 0]], dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)

# 1D * 2D * 1D
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

# 2D * 2D
out = A.dot(A * alpha) + beta * A
fgraph = FunctionGraph([A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Loading