Open
Labels: bug (Something isn't working)
Description
Is it possible to run this simple kernel successfully?
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def kernel(
    diags_ref,
    solves_ref,
    lowers_ref,
    indices_ref,
    out_solve_ref,
):
    i_neuron = pl.program_id(0)
    diags = pl.load(diags_ref, i_neuron)
    solves = pl.load(solves_ref, i_neuron)
    lowers = pl.load(lowers_ref, pl.dslice(None))
    lowers = lowers.at[0].set(0.0)
    lower_effect = -lowers / diags
    solve_effect = solves / diags
    for i in range(indices_ref.shape[0]):
        index = pl.load(indices_ref, i)
        solve_effect = lower_effect * solve_effect[index] + solve_effect
        lower_effect = lower_effect * lower_effect[index]
    pl.store(out_solve_ref, i_neuron, solve_effect)


@jax.jit
def run_kernel(diags, solves, lowers, indices):
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(diags.shape, diags.dtype),
        grid=(diags.shape[0],),
    )(diags, solves, lowers, indices)


diags = jnp.asarray(np.random.randn(1024, 128).astype(np.float32))
solves = jnp.asarray(np.random.randn(1024, 128).astype(np.float32))
lowers = jnp.asarray(np.random.randn(128).astype(np.float32))
indices = jnp.asarray(np.random.randint(0, 128, size=(7, 128)).astype(np.int32))

run_kernel(diags, solves, lowers, indices)
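
For reference, here is a minimal NumPy sketch of what the kernel above appears to compute, written outside Pallas so the expected result can be checked independently. It assumes each grid program handles one row of `diags` and `solves`, and that `solve_effect[index]` is meant as an elementwise gather along the 128-wide axis. The name `reference_solve` is hypothetical and is not part of JAX or Pallas.

```python
import numpy as np


def reference_solve(diags, solves, lowers, indices):
    """Hypothetical NumPy reference for the kernel; all rows are processed at once."""
    lowers = lowers.copy()          # avoid mutating the caller's array
    lowers[0] = 0.0                 # mirrors lowers.at[0].set(0.0)
    lower_effect = -lowers / diags  # (128,) broadcasts against (rows, 128)
    solve_effect = solves / diags
    for index in indices:           # each index vector has shape (128,)
        # Gather along the last axis; the same index vector is used for every row.
        solve_effect = lower_effect * solve_effect[:, index] + solve_effect
        lower_effect = lower_effect * lower_effect[:, index]
    return solve_effect
```

If the Pallas kernel can be made to run, its output could be compared against this reference with `np.testing.assert_allclose`, using a tolerance suited to float32.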
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 13:27:36) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 3080 Ti Laptop GPU-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='LeiShen-of-WCM', release='6.6.87.2-microsoft-standard-WSL2', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 5 18:30:46 UTC 2025', machine='x86_64')
$ nvidia-smi
Wed Sep 17 21:56:53 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.04 Driver Version: 576.52 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3080 ... On | 00000000:01:00.0 On | N/A |
| N/A 46C P0 35W / 120W | 2155MiB / 16384MiB | 4% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 23647 C /python3.12 N/A |
+-----------------------------------------------------------------------------------------+