My current understanding is that:
1. external memory can be imported into JAX through __cuda_array_interface__ or __dlpack__() + __dlpack_device__() (this works in-place if used properly);
2. jax.array_ref can be called on a JAX array to obtain a mutable ArrayRef.

A clarification first: some external systems, such as google-deepmind/mujoco_warp, let JAX allocate and manage the arrays instead of importing external memory, which gives tighter and smoother integration. I cannot do that in my use case, because memory allocated by JAX lacks a property required by the external system I'm trying to interface with. Another option is to maintain a copy on the JAX side, but for large arrays this means unnecessary memory consumption and copy overhead.
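The first step of that workflow, importing external memory through the DLPack protocol, can be sketched as below. This is a minimal example under stated assumptions: a NumPy array on the CPU stands in for the externally allocated buffer (the real use case presumably involves device memory), and jnp.from_dlpack performs the import. Comparing imported.unsafe_buffer_pointer() with the source address is one way to check whether JAX actually aliased the memory or made a copy. Whether it aliases can depend on backend, device, and alignment, so the sketch reports the result rather than assuming it.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Stand-in for memory owned by the external system (assumption: in the real
# use case this would be a device buffer exposing __dlpack__ and
# __dlpack_device__, or __cuda_array_interface__).
external = np.arange(8, dtype=np.float32)

# Import via DLPack: jnp.from_dlpack consumes the object through its
# __dlpack__ / __dlpack_device__ protocol methods.
imported = jnp.from_dlpack(external)

# Compare the address JAX reports with the source address to see whether
# the buffer was aliased (zero-copy) or copied during import.
jax_ptr = imported.unsafe_buffer_pointer()
ext_ptr = external.ctypes.data
print("zero-copy import:", jax_ptr == ext_ptr)
```

On a GPU, the same comparison against the external system's device pointer would show whether the import was truly in-place before any ArrayRef is involved.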
Here I have several confusions:
1. jax.array_ref copies the jax.Array internally (understandable, since this maintains JAX's fundamental immutable semantics), which prevents me from obtaining an in-place mutable reference to the imported array. A hack is to call
   from jax._src.state.types import AbstractRef; ref = jax.ArrayRef(AbstractRef(arr.aval), arr)
   and it does work, but I'd like to avoid relying on an internal API. So here are two questions:
   a. An ArrayRef cannot be passed to an FFI call directly (I get a confusing state discharge error). In this case, how do I pass the exact array to the FFI function?
   b. jax.ref.freeze leads to another copy, while passing a jax.Array gives no guarantee that the device pointer received by the FFI is identical to the intended one. Since the array is external to JAX, I could of course pass it to the FFI as context outside JAX, but I worry that an uncaptured implicit dependency would lead to ordering issues.