-
To answer my own question: simply jit a smaller function with the desired device:

pinv = jax.jit(jnp.linalg.pinv, device=jax.devices('cpu')[0])

@jax.jit
@jax.grad
def f(x):
    return pinv(x).sum()

%time jax.block_until_ready(f(x))
%timeit jax.block_until_ready(f(x))

CPU times: user 467 ms, sys: 49 ms, total: 516 ms
Wall time: 471 ms
220 ms ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-
I've been having similar issues. I ended up wrapping the CPU part in a pure_callback, and it seems to work, but I'm not totally sure.
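For reference, here is a minimal sketch of what that wrapping might look like (the helper names pinv_on_host, pinv_cpu, and f are illustrative, not from the original post). Note that jax.pure_callback has no differentiation rule by default, so taking gradients through it would additionally require a custom_vjp:

import jax
import jax.numpy as jnp
import numpy as np

def pinv_on_host(x):
    # Runs outside the traced computation, on the host CPU, via NumPy.
    return np.linalg.pinv(x)

def pinv_cpu(x):
    # pinv of a (..., m, n) array has shape (..., n, m).
    out = jax.ShapeDtypeStruct(x.shape[:-2] + (x.shape[-1], x.shape[-2]), x.dtype)
    return jax.pure_callback(pinv_on_host, out, x)

@jax.jit
def f(x):
    return pinv_cpu(x).sum()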
-
In general, I believe the most general way to achieve this is to use jax.lax.with_sharding_constraint. Unfortunately we don't have great docs on this just yet, but you can see an example below:

import jax
import jax.numpy as jnp
import numpy as np

cuda_devices = jax.devices('cuda')
cpu_devices = jax.devices('cpu')
gpu_sharding = jax.sharding.SingleDeviceSharding(cuda_devices[0])
cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])

x = jnp.array(np.random.normal(size=(4096, 7, 64)))

@jax.jit
@jax.grad
def f_gpu(x):
    x = jax.lax.with_sharding_constraint(x, gpu_sharding)
    return jnp.linalg.pinv(x).sum()

%time jax.block_until_ready(f_gpu(x))
%timeit jax.block_until_ready(f_gpu(x))

@jax.jit
@jax.grad
def f_cpu(x):
    x = jax.lax.with_sharding_constraint(x, cpu_sharding)
    return jnp.linalg.pinv(x).sum()

%time jax.block_until_ready(f_cpu(x))
%timeit jax.block_until_ready(f_cpu(x))
-
As the title suggests, I want to run part of my program on another device. Specifically, since GPUs are quite bad at SVD, I'd like to do the SVD on the CPU. In TensorFlow and PyTorch it is possible to simply annotate part of a function to run on a different device. However, in JAX I get the following behavior:

Using default_device: moves the whole program to the CPU.

Since I want most of my program to run on the GPU, running everything on the CPU is not an option. I am aware of host callbacks, which could be applied here. But is there a different solution?
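For illustration, a minimal sketch of the default_device attempt described above (f and x are placeholders for the actual function and input):

# Everything traced inside this context, not just the SVD, is placed on the CPU.
with jax.default_device(jax.devices('cpu')[0]):
    result = f(x)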