diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py
index 1f53daf96da3..0d95f24431ea 100644
--- a/jax/_src/pallas/mosaic/sc_core.py
+++ b/jax/_src/pallas/mosaic/sc_core.py
@@ -16,8 +16,9 @@
 from __future__ import annotations

 import collections
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 import dataclasses
+import functools
 import math
 from typing import Any, TypeAlias

@@ -340,8 +341,9 @@ def wrapper(*args):
    @pallas_core.core_map(mesh, **kwargs)
    def _():
      return pallas_primitives.run_scoped(
-          lambda *scratch_refs: body(*arg_refs, *out_refs, *scratch_refs),
-          *scratch_shapes,
+          functools.partial(body, *arg_refs, *out_refs),
+          *scratch_shapes if isinstance(scratch_shapes, Sequence) else (),
+          **scratch_shapes if isinstance(scratch_shapes, Mapping) else {},
      )

    outs = jax.tree.map(lambda ref: ref[...], out_refs)
diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py
index 9b027a4a22da..491c51e561e8 100644
--- a/tests/pallas/tpu_sparsecore_pallas_test.py
+++ b/tests/pallas/tpu_sparsecore_pallas_test.py
@@ -1145,10 +1145,12 @@ def test_barrier_via_pallas_call(self):
            shape=(mesh.num_subcores, vec_dim), dtype=jnp.uint32
        ),
        out_specs=pl.BlockSpec((1, vec_dim), lambda i: (i, 0)),
-        scratch_shapes=[
-            pltpu.VMEM_SHARED((mesh.num_subcores, vec_dim), jnp.uint32),
-            pltpu.VMEM((vec_dim,), jnp.uint32),
-        ],
+        scratch_shapes=dict(
+            shared_ref=pltpu.VMEM_SHARED(
+                (mesh.num_subcores, vec_dim), jnp.uint32
+            ),
+            vmem_ref=pltpu.VMEM((vec_dim,), jnp.uint32),
+        ),
    )
    def kernel(o_ref, shared_ref, vmem_ref):
      subcore_id = pl.program_id(0)
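
For context, a minimal standalone sketch of the dispatch introduced in `wrapper` above: a single `scratch_shapes` argument may now be either a `Sequence` (splatted positionally, the previous behaviour) or a `Mapping` (splatted by keyword, so the kernel body can take named scratch refs, as in the updated test). The `run_scoped_sketch` and `call_with_scratch` helpers below are illustrative stand-ins, not Pallas APIs.

```python
import functools
from collections.abc import Mapping, Sequence


def run_scoped_sketch(body, *shapes, **named_shapes):
  # Stand-in for pallas_primitives.run_scoped: simply forwards the shapes to
  # the body as if they were the allocated scratch refs.
  return body(*shapes, **named_shapes)


def call_with_scratch(body, arg_refs, out_refs, scratch_shapes):
  # Same dispatch as in the diff: bind the input/output refs up front with
  # functools.partial, then splat a Sequence positionally or a Mapping by
  # keyword.
  return run_scoped_sketch(
      functools.partial(body, *arg_refs, *out_refs),
      *(scratch_shapes if isinstance(scratch_shapes, Sequence) else ()),
      **(scratch_shapes if isinstance(scratch_shapes, Mapping) else {}),
  )


# Positional scratch refs (previous behaviour, still supported):
call_with_scratch(
    lambda a, o, s0, s1: (s0, s1), ["a"], ["o"], ["scratch0", "scratch1"]
)

# Named scratch refs (what this change enables):
call_with_scratch(
    lambda a, o, shared_ref, vmem_ref: (shared_ref, vmem_ref),
    ["a"],
    ["o"],
    dict(shared_ref="scratch0", vmem_ref="scratch1"),
)
```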