
Unimplemented primitive in Pallas GPU lowering: scatter. #31876

@chaoming0625

Description

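Running the simple kernel below on a GPU fails with "Unimplemented primitive in Pallas GPU lowering: scatter." Is it possible to get it to run successfully?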

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def kernel(
    diags_ref,
    solves_ref,
    lowers_ref,
    indices_ref,
    out_solve_ref,
):
    i_neuron = pl.program_id(0)

    diags = pl.load(diags_ref, i_neuron)
    solves = pl.load(solves_ref, i_neuron)
    lowers = pl.load(lowers_ref, pl.dslice(None))

    lowers = lowers.at[0].set(0.0)  # `.at[...].set` is traced as a `scatter`
    lower_effect = -lowers / diags
    solve_effect = solves / diags
    for i in range(indices_ref.shape[0]):
        index = pl.load(indices_ref, i)
        # Indexing an in-register array with a vector of indices is a `gather`.
        solve_effect = lower_effect * solve_effect[index] + solve_effect
        lower_effect = lower_effect * lower_effect[index]

    pl.store(out_solve_ref, i_neuron, solve_effect)


@jax.jit
def run_kernel(diags, solves, lowers, indices):
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(diags.shape, diags.dtype),
        grid=(diags.shape[0],),
    )(diags, solves, lowers, indices)


diags = jnp.asarray(np.random.randn(1024, 128).astype(np.float32))
solves = jnp.asarray(np.random.randn(1024, 128).astype(np.float32))
lowers = jnp.asarray(np.random.randn(128).astype(np.float32))
indices = jnp.asarray(np.random.randint(0, 128, size=(7, 128)).astype(np.int32))
run_kernel(diags, solves, lowers, indices)
```
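A possible workaround, offered as an untested sketch rather than a confirmed fix: the `scatter` in the error most likely comes from `lowers.at[0].set(0.0)`, and `solve_effect[index]` / `lower_effect[index]` are gathers on in-register values, which may run into a similar lowering limitation once the scatter is gone. Both can be written with elementwise ops plus a reduction instead: a `jnp.where` mask for the zeroing, and a one-hot masked sum for the indexed lookup. The snippet checks that rewrite in plain JAX on CPU against the original formulation. `reference_row`, `take_via_onehot` and `scatter_free_row` are hypothetical helper names, and whether every op they use (`jnp.arange`, broadcast comparisons, `jnp.where`, `jnp.sum`) lowers through Pallas on GPU in jax 0.5.0 is an assumption that has not been verified here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_row(diags, solves, lowers, indices):
    """Per-neuron recurrence as written in the kernel (uses scatter and gather)."""
    lowers = lowers.at[0].set(0.0)  # traced as a `scatter`
    lower_effect = -lowers / diags
    solve_effect = solves / diags
    for i in range(indices.shape[0]):
        index = indices[i]
        solve_effect = lower_effect * solve_effect[index] + solve_effect  # gather
        lower_effect = lower_effect * lower_effect[index]  # gather
    return solve_effect


def take_via_onehot(x, index):
    """Hypothetical helper: x[index] for 1-D x, as a masked sum instead of a gather."""
    n = x.shape[0]
    # onehot[j, k] is True exactly where index[j] == k.
    onehot = index[:, None] == jnp.arange(n)[None, :]
    # Each row keeps a single element of x; summing the row extracts it.
    return jnp.sum(jnp.where(onehot, x[None, :], 0.0), axis=1)


def scatter_free_row(diags, solves, lowers, indices):
    """Hypothetical helper: same recurrence without `.at[...].set` or vector indexing."""
    n = lowers.shape[0]
    # Zero element 0 with a mask instead of a scatter.
    lowers = jnp.where(jnp.arange(n) == 0, 0.0, lowers)
    lower_effect = -lowers / diags
    solve_effect = solves / diags
    for i in range(indices.shape[0]):
        index = indices[i]
        solve_effect = lower_effect * take_via_onehot(solve_effect, index) + solve_effect
        lower_effect = lower_effect * take_via_onehot(lower_effect, index)
    return solve_effect


# Small, well-conditioned inputs (diags kept away from zero) for a CPU check.
rng = np.random.default_rng(0)
diags = jnp.asarray(2.0 + rng.random((8, 128), dtype=np.float32))
solves = jnp.asarray(rng.standard_normal((8, 128), dtype=np.float32))
lowers = jnp.asarray(rng.uniform(-0.5, 0.5, size=128).astype(np.float32))
indices = jnp.asarray(rng.integers(0, 128, size=(7, 128), dtype=np.int32))

# One row per neuron; `lowers` and `indices` are shared, as in the kernel.
row_axes = (0, 0, None, None)
ref = jax.vmap(reference_row, in_axes=row_axes)(diags, solves, lowers, indices)
alt = jax.vmap(scatter_free_row, in_axes=row_axes)(diags, solves, lowers, indices)
np.testing.assert_allclose(np.asarray(alt), np.asarray(ref), rtol=1e-6, atol=1e-6)
```

The one-hot lookup does O(n²) work per step (128 × 128 comparisons per neuron here) instead of O(n), so it trades extra compute for avoiding the unsupported primitives. Inside the kernel, the same expressions would stand in for `lowers.at[0].set(0.0)` and the `[index]` lookups, with `pl.load(indices_ref, i)` still supplying `index`.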

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:27:36) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 3080 Ti Laptop GPU-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='LeiShen-of-WCM', release='6.6.87.2-microsoft-standard-WSL2', version='#1 SMP PREEMPT_DYNAMIC Thu Jun  5 18:30:46 UTC 2025', machine='x86_64')
$ nvidia-smi
Wed Sep 17 21:56:53 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.04              Driver Version: 576.52         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3080 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   46C    P0             35W /  120W |    2155MiB /  16384MiB |      4%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A           23647      C   /python3.12                           N/A      |
+-----------------------------------------------------------------------------------------+
```

Labels: bug (Something isn't working)