Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement @as_jax_op to wrap a JAX function for use in PyTensor #1120

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

jdehning
Copy link

@jdehning jdehning commented Dec 12, 2024

Description

Add a decorator that transforms a JAX function such that it can be used in PyTensor. Shape and dtype inference works automatically and input and output can be any nested python structure (e.g. Pytrees). Furthermore, using a transformed function as an argument for another transformed function should also work.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1120.org.readthedocs.build/en/1120/

@jdehning
Copy link
Author

jdehning commented Dec 12, 2024

I have a question, where should I put the @as_jax_op. Currently, it is in a new file pytensor/link/jax/ops.py. Does that make sense? Also, how should one access it? Only by calling pytensor.link.jax.ops.as_jax_op? Or include it in a __init__.py such that pytensor.as_jax_op works?

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 12, 2024

We can put in init as long as imports work in a way that jax is still optional for Pytensor users (obviously calling the decorator can raise if it's not installed, hopefully with an informative message)

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks sweet. I'll do a more careful review later, just skimmed through and annotated some thoughts

self.num_inputs = len(inputs)

# Define our output variables
outputs = [pt.as_tensor_variable(type()) for type in self.output_types]
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use jax machinery to infer the output types from the input types? Can we created TraceDArrays (or whatever they're called) and pass them through the function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scrap that, JAX doesn't let you trace arrays without unknown shape

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trace the shape through the JAX function in line 119 of the file. It won't work for unknown shape. But if one specifies the shape at the beginning of a graph, i.e. x = pm.Normal("x", shape=(3,)), and it loses static shape information afterwards, for instance because of a pt.cumsum, line 99 (pytensor.compile.builders.infer_shape) will be able to infer the shape. But that is a good comment, I will raise an error if pytensor.compile.builders.infer_shape isn't able to infer the shape. I think it makes sense to only use this wrapper if the shape information is known.

Copy link
Author

@jdehning jdehning Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see a point where it will lead to problems: If there is an input x = pm.Data("x", shape=(None,), value= np.array([0., 0])): in the first run, it will work, as pytensor.compile.builders.infer_shape will infer the shape as (2,), but if one changes with x.set_value(np.array([0., 0, 0])) the shape of x, it will lead to an error, as the Pytensor Op has been created with an explicit shape. I could simply add a parameter to as_jax_op to force all output shapes to None, then it should work.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will write more tests, then it will be clearer what I mean

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can't use shape unless it's static. Ideally it shouldn't fail for unknown shapes, but then the user has to tell user the output types.

We can allow the user to specify a make_node callable? That way it can be made to work with different dtypes/ndims if the jax function handles those fine

return (result,) # Pytensor requires a tuple here

# vector-jacobian product Op
class VJPSolOp(Op):
Copy link
Member

@ricardoV94 ricardoV94 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a nice follow up would be to also create a "ValueAndGrad" version of the Op that gets introduced in rewrites when both the Op and the VJP of Op (or JVP) are in the final graph.

This need not be a blocker for this PR

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see exactly what you mean. Is ValueAndGrad used by Pytensor? I searched the codebase but didn't find a mention of it. Does it have to do with implementing L_op? I haven't really understood the difference between it and grad

Copy link
Member

@ricardoV94 ricardoV94 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JAX has the value and grad concept to more optimally compute both together. PyTensor doesn't have that concept because everything is lazy but we can exploit it during the rewrite phase.

If a user compiles a function that includes both forward and gradient of the same wrapped JAX Op, we could replace it by a third Op whose perform implementation requests jax to compute both.

This is not relevant when the autodiff is done in JAX, but it's relevant when it's done in PyTensor

jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
)

@jax_funcify.register(SolOp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can dispatch on the base class just once?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? This jax.funcify is once registering SolOp, once VJPSolOp. You mean, one could include the gradient calculation in SolOp?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can define SolOp class outside the decorator and dispatch on that.

Then the decorator can return a subclass of that and you don't need to bother dispatching because the base class dispatch will cover it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I didn't think of that. Have a look at whether I implemented it like you had envisioned

pytensor/link/jax/ops.py Outdated Show resolved Hide resolved
jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
)

@jax_funcify.register(SolOp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can define SolOp class outside the decorator and dispatch on that.

Then the decorator can return a subclass of that and you don't need to bother dispatching because the base class dispatch will cover it

@ricardoV94
Copy link
Member

Big level picture. What's going on with the flattening of inputs and why is it needed?

@jdehning
Copy link
Author

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

@ricardoV94
Copy link
Member

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

And if I have a matrix input function will this work or expect it to be a vector instead?

@jdehning
Copy link
Author

Big level picture. What's going on with the flattening of inputs and why is it needed?

To be able to wrap JAX function that accept pytrees as input. pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve) works, and can be used in the same way as one would use the original diffrax.diffeqsolve.

And if I have a matrix input function will this work or expect it to be a vector instead?

It will work, it doesn't change pytensor.Variables, a matrix will stay a matrix. What it does, is to flatten nested python structure, e.g. {"a": tensor_a, "b": [tensor_b, tensor_c]} becomes [tensor_a, tensor_b, tensor_c] (and a treedef object which saves the structure of the tree), where tensor_x are three different tensors of potentially different shape and dtype. As pytensor operators accept a list of tensors as input, the flattened version can be used to define our op. The shapes of the tensors aren't changed. This is also basically how operators in JAX are written, see the second code box in this paragraph: https://jax.readthedocs.io/en/latest/autodidax.html#pytrees-and-flattening-user-functions-inputs-and-outputs

@jdehning
Copy link
Author

I would begin in parallel to write an example notebook. I opened an issue here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement helper @as_jax_op to wrap JAX functions in PyTensor
2 participants