I am currently trying to extend the JAX Python package with Rust bindings. My goal is to adapt the "extending JAX with C++" tutorial to Rust, using PyO3 and Maturin if possible.
Unfortunately, generating XLA-compatible code relies heavily on C++ features such as macros and templates, for which it seems impossible to generate bindings automatically (through `autocxx`, for example). The partial solution I have come up with is to:

1. write my function (`rms_norm`) in Rust;
2. export it to C++;
3. wrap the exported function with XLA code inside C++;
4. export the wrapped function (`extern "C"`) back to Rust;
5. and register the function pointer in the module as a `PyCapsule`.
I could probably stop at step (3) and use, e.g., nanobind to generate the Python module from C++, but I would prefer to stick with PyO3 and Maturin, as they are much more convenient in my case :-)
The Rust code looks as follows:
```rust
use std::ffi::CString;

use pyo3::prelude::*;
use pyo3::types::PyCapsule;

fn rms_norm(eps: f32, x: &[f32], y: &mut [f32]) {
    /* actual implementation */
}

#[cxx::bridge]
mod ffi {
    extern "Rust" {
        // Expose our Rust function to C++
        fn rms_norm(eps: f32, x: &[f32], y: &mut [f32]) -> ();
    }

    unsafe extern "C++" {
        include!("rms-norm/include/ffi.h");

        type XLA_FFI_Error;
        type XLA_FFI_CallFrame;

        // This is the C++ XLA-compatible wrapper around our `rms_norm` Rust function
        unsafe fn RmsNorm(call_frame: *mut XLA_FFI_CallFrame) -> *mut XLA_FFI_Error;
    }
}

#[pymodule]
fn _rms_norm(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let name = CString::new("rms_norm").unwrap();
    let f: unsafe fn(*mut ffi::XLA_FFI_CallFrame) -> *mut ffi::XLA_FFI_Error = ffi::RmsNorm;
    m.add("rms_norm", PyCapsule::new(m.py(), f, Some(name))?)?;
    Ok(())
}
```
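The body of `rms_norm` is elided above. As a minimal sketch, assuming it mirrors the RMS normalization from the JAX tutorial (every element scaled by 1/sqrt(mean(x²) + eps)), it might look like this; the length check and the early return on empty input are added assumptions, not part of the original:

```rust
/// Hypothetical body for `rms_norm`: y = x / sqrt(mean(x^2) + eps).
pub fn rms_norm(eps: f32, x: &[f32], y: &mut [f32]) {
    // Assumption: the caller provides an output buffer of the same length.
    assert_eq!(x.len(), y.len(), "x and y must have the same length");
    if x.is_empty() {
        return; // nothing to normalize, and avoids dividing by zero below
    }
    // Mean of the squared entries.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // One scale factor shared by every element.
    let scale = 1.0 / (mean_sq + eps).sqrt();
    for (yi, xi) in y.iter_mut().zip(x) {
        *yi = xi * scale;
    }
}
```

Because it only works on slices, this kernel can be unit-tested in plain Rust before any of the C++/XLA plumbing is involved.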
When registering the custom FFI call (inside Python), with:
it causes a segmentation fault a bit later, so my question is: am I passing the pointer to the C++ function correctly?
In the C++ example, the callable¹ is registered as a `*const void` and stored inside a `PyCapsule`. I have tried to mimic this, but your `PyCapsule` constructor doesn't let me pass a raw pointer directly, because raw pointers are not `Send`, so I cast the function to a function pointer instead.
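To make the question concrete, here is a std-only sketch (no PyO3 or XLA; `CallFrame`, `FfiError`, `Handler`, `fake_handler` and the helper functions are hypothetical stand-ins). It assumes, as the `&T` returned by PyO3's `PyCapsule::import` suggests, that `PyCapsule::new` keeps the value in its own storage and that the capsule's raw pointer refers to that storage. Under that assumption, a capsule built from `f` carries the address of a slot *holding* the function pointer, whereas the C++ example stores the function's address itself; if JAX treats the capsule pointer as the handler, as the C++ example implies, it would jump into data rather than code. The sketch also types the handler as `unsafe extern "C" fn`, because the post's `unsafe fn(...)` annotation denotes a Rust-ABI function, and Rust does not guarantee that ABI matches C's.

```rust
use std::ffi::c_void;
use std::ptr;

/// Hypothetical opaque stand-in for XLA's `XLA_FFI_CallFrame`.
#[repr(C)]
pub struct CallFrame {
    _private: [u8; 0],
}

/// Hypothetical opaque stand-in for XLA's `XLA_FFI_Error`.
#[repr(C)]
pub struct FfiError {
    _private: [u8; 0],
}

/// A C-ABI handler type: takes a call frame, returns an error pointer (null on success).
pub type Handler = unsafe extern "C" fn(*mut CallFrame) -> *mut FfiError;

/// Hypothetical handler that does nothing and reports success (null).
pub unsafe extern "C" fn fake_handler(_frame: *mut CallFrame) -> *mut FfiError {
    ptr::null_mut()
}

/// Layout A: the pointer *is* the function's address (what the C++ example stores).
pub fn address_of_function(f: Handler) -> *mut c_void {
    f as *mut c_void
}

/// Layout B: the pointer is the address of a slot that *holds* the function
/// pointer (what you get if the value is first moved into separate storage).
pub fn address_of_slot(slot: &Handler) -> *mut c_void {
    slot as *const Handler as *mut c_void
}

/// Turning layout A back into a callable handler: reinterpret the address.
pub unsafe fn handler_from_function_address(p: *mut c_void) -> Handler {
    unsafe { std::mem::transmute::<*mut c_void, Handler>(p) }
}

/// Turning layout B back into a callable handler: one extra dereference,
/// which a consumer expecting layout A will not perform.
pub unsafe fn handler_from_slot_address(p: *mut c_void) -> Handler {
    unsafe { *(p as *const Handler) }
}
```

If the assumption holds, one direction worth trying is to create the capsule through the raw CPython API exposed as `pyo3::ffi::PyCapsule_New`, passing a C-ABI function's address as the pointer; this is untested here.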
¹ They also wrap the `PyCapsule` inside a lambda-like function, meaning they need to call `rms_norm()` to actually obtain the `PyCapsule`, but I don't think this additional level of nesting is necessary, is it?
Thanks for your help!
You can find the full MWE code here: https://github.com/jeertmans/extending-jax/tree/5297b806c8f434030612875e270e3f598ec0e38d