Pallas outside Python #24453
Unanswered
krzysztofrusek
asked this question in
Q&A
Replies: 1 comment
-
It works with a TensorFlow SavedModel. I am posting an example for future readers.

Model (export script):

import jax
import functools

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental import pallas as pl
def add_vectors_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel body: store the elementwise sum of two input refs.

    Args:
        x_ref: Ref holding the first operand.
        y_ref: Ref holding the second operand.
        o_ref: Ref receiving the result; written in place, nothing is returned.
    """
    # Ellipsis indexing loads the full contents of each ref ...
    lhs = x_ref[...]
    rhs = y_ref[...]
    # ... and, on the left-hand side, stores into the full output ref.
    o_ref[...] = lhs + rhs
@functools.partial(jax.jit, static_argnames=("interpret",))
def add_vectors(x: jax.Array, y: jax.Array, *, interpret: bool = False) -> jax.Array:
    """Add two arrays elementwise using the Pallas kernel ``add_vectors_kernel``.

    No grid or BlockSpecs are passed to ``pallas_call``, so (per Pallas
    defaults) the kernel is invoked once on the full arrays.

    Args:
        x: First operand; its shape and dtype define the output.
        y: Second operand.
        interpret: If True, run the kernel with the Pallas interpreter instead
            of compiling it for the accelerator (e.g. to test on CPU). It is a
            static argument under ``jax.jit``: changing it retraces. The
            default (False) matches the previous behavior.

    Returns:
        An array with ``x.shape`` and ``x.dtype`` holding ``x + y``.
    """
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)
# Smoke-test the kernel in JAX before exporting. Use the same shape and dtype
# as the exported signature below (length-8 float32): a bare jnp.arange(8) is
# an integer array, so it would trace the kernel for a different dtype than
# the one that gets exported.
add_vectors(jnp.arange(8, dtype=jnp.float32), jnp.arange(8, dtype=jnp.float32))

my_model = tf.Module()
# Export a function that takes two length-8 float32 vectors.
my_model.f = tf.function(
    jax2tf.convert(add_vectors),
    autograph=False,
    input_signature=[
        tf.TensorSpec([8], tf.float32),
        tf.TensorSpec([8], tf.float32),
    ],
)
# Custom gradients are not saved with the model.
# NOTE(review): presumably required because the gradient of the converted
# function (through pallas_call) cannot be serialized — confirm.
tf.saved_model.save(
    my_model,
    'mod',
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
)

# ---------------------------------------------------------------------------
# Server: a separate script that imports only TensorFlow.
# ---------------------------------------------------------------------------
import tensorflow as tf
# Load the SavedModel that the export script wrote to 'mod' (relative to the
# current working directory) and call the exported function. The inputs must
# match its input signature: two length-8 float32 vectors.
restored_model = tf.saved_model.load('./mod')
x = tf.convert_to_tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32)
y = tf.convert_to_tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32)
print(restored_model.f(x, y))
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Hello, is there a way to run Pallas kernels (Triton or Mosaic GPU) outside the Python ecosystem?
Another related discussion: #20508
Beta Was this translation helpful? Give feedback.
All reactions