
Commit 391cbc1

Remove lambdas
1 parent ec7f0b6 commit 391cbc1

5 files changed: +102 -61 lines changed

jax_triton/pallas/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 
 """Module for pallas, a jaxpr "dialect" for Triton."""
 from jax_triton.pallas.core import BlockSpec
-from jax_triton.pallas.core import Config
+from jax_triton.pallas.core import KernelConfig
 from jax_triton.pallas.pallas_call import pallas_call
 from jax_triton.pallas.pallas_call import pallas_call_p
 from jax_triton.pallas.primitives import atomic_add
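
In user code, the visible effect of this re-export change is that configs are now constructed as pl.KernelConfig instead of pl.Config (the Config class itself still exists in core.py; it is simply no longer exported from the package namespace). A minimal before/after sketch, with argument values borrowed from the updated tests:

from jax_triton import pallas as pl

# before this commit
cfg = pl.Config(dict(block_size=2), {})
# after this commit
cfg = pl.KernelConfig(meta=dict(block_size=2), grid=8 // 2)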

jax_triton/pallas/core.py

Lines changed: 13 additions & 1 deletion
@@ -18,7 +18,7 @@
 import functools
 from functools import partial
 
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union
 
 import jax.numpy as jnp
 from jax._src import api_util
@@ -95,6 +95,18 @@ class GridSpec:
 
 Platform = str
 
+
+@dataclasses.dataclass
+class KernelConfig:
+  in_specs: Optional[Sequence[Optional[BlockSpec]]] = None
+  out_specs: Optional[Sequence[Optional[BlockSpec]]] = None
+  grid: Optional[Union[Grid, int]] = None
+  meta: dict[str, Any] = dataclasses.field(default_factory=dict)
+  compiler_params: dict[Platform, dict[str, Any]] = dataclasses.field(default_factory=dict)
+
+  def replace(self, *args, **kwargs):
+    return dataclasses.replace(self, *args, **kwargs)
+
 @dataclasses.dataclass
 class Config:
   meta: dict[str, Any]
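
For orientation, a small sketch of how this new dataclass is meant to be filled in, mirroring its use in the updated tests further down (the block_size and grid values here are illustrative):

from jax_triton import pallas as pl

# meta holds keyword arguments forwarded to the kernel, grid may be an int or a
# tuple, and in_specs/out_specs default to None (whole-array references).
config = pl.KernelConfig(meta=dict(block_size=4), grid=8 // 4)

# replace() returns a copy with selected fields overridden; pallas_call uses it
# when canonicalizing the grid and specs.
wider = config.replace(meta=dict(block_size=8), grid=8 // 8)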

jax_triton/pallas/pallas_call.py

Lines changed: 71 additions & 42 deletions
@@ -306,24 +306,16 @@ def _compute_spec(config: Config, spec: MaybeSpec,
     spec = spec(**config.meta)
   return spec
 
-def specialize_kernel(config: Config,
+def specialize_kernel(config: pallas_core.KernelConfig,
                       func: Callable,
-                      grid: Optional[pallas_core.Grid],
                       name: Optional[str],
-                      in_specs: Optional[list[Optional[BlockSpec]]],
-                      out_specs: Optional[list[Optional[BlockSpec]]],
                       in_avals: tuple[jax_core.ShapedArray, ...],
                       out_avals: tuple[jax_core.ShapedArray, ...],
                       in_tree: tree_util.PyTreeDef,
                       compiler_params: dict[str, Any]
                       ) -> tuple[SpecializedKernel, ...]:
-  specialized_grid = grid
-  if callable(specialized_grid):
-    specialized_grid = specialized_grid(**config.meta)
-  specialized_grid = pallas_core.preprocess_grid(specialized_grid)
-  specialized_in_specs = map(partial(_compute_spec, config), in_specs)
-  specialized_out_specs = map(partial(_compute_spec, config), out_specs)
-  if specialized_grid == ():
+  grid = config.grid
+  if grid == ():
     in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
                     for arg in in_avals]
     out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
@@ -333,42 +325,76 @@ def specialize_kernel(config: Config,
         state.shaped_array_ref(
             pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
             aval.dtype)
-        for block_spec, aval in zip(specialized_in_specs, in_avals)]
+        for block_spec, aval in zip(config.in_specs, in_avals)]
     out_ref_avals = [
         state.shaped_array_ref(
            pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
            aval.dtype)
-        for block_spec, aval in zip(specialized_out_specs, out_avals)]
-  in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                          specialized_in_specs)
-  out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                           specialized_out_specs)
-  grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ())
+        for block_spec, aval in zip(config.out_specs, out_avals)]
+  in_block_mappings = map(
+      partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+      config.in_specs)
+  out_block_mappings = map(
+      partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+      config.out_specs)
+  grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ())
   jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr(
       func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta)
   return SpecializedKernel("foo", jaxpr, len(consts), grid_spec,
                            dict(compiler_params, **config.compiler_params)), consts, out_tree
 
-def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
+def _canonicalize_kernel_config(
+    maybe_kernel_config: Optional[pallas_core.KernelConfig],
+    in_avals: Sequence[jax_core.AbstractValue],
+    out_avals: Sequence[jax_core.AbstractValue],
+    in_specs: Optional[Sequence[Optional[BlockSpec]]],
+    out_specs: Optional[Sequence[Optional[BlockSpec]]],
+    grid: Optional[Union[Grid, int]],
+    ) -> pallas_core.KernelConfig:
+  if not maybe_kernel_config:
+    config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid)
+  else:
+    config = maybe_kernel_config
+    grid = maybe_kernel_config.grid
+  grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs
+  grid = pallas_core.preprocess_grid(grid)
+  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
+    in_specs = (in_specs,)
+  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
+    out_specs = (out_specs,)
+  if in_specs is None:
+    in_specs = [None] * len(in_avals)
+  if out_specs is None:
+    out_specs = [None] * len(out_avals)
+  return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs)
+
+def pallas_call(f: Callable, out_shape: Any, *,
                 grid: Optional[Grid] = None,
+                config: Optional[pallas_core.KernelConfig] = None,
                 in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 out_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 input_output_aliases: Dict[int, int] = {},
                 interpret: bool = False,
                 name: Optional[str] = None,
-                autotuning_configs: Optional[list[Config]] = None,
+                autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None,
+                debug: bool = False,
                 **compiler_params: Any):
+  if config is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `config` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
+    if autotuning_configs is not None:
+      raise ValueError("Cannot specify both `config` and `autotuning_configs`")
+  if autotuning_configs is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
   singleton = False
   if not isinstance(out_shape, (tuple, list)):
     out_shape = (out_shape,)
     singleton = True
   if not isinstance(out_shape, tuple):
     out_shape = tuple(out_shape)
-  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
-    in_specs = (in_specs,)
-  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-    out_specs = (out_specs,)
-
   if not name:
     name = f.__name__ if hasattr(f, "__name__") else "unnamed"
 
@@ -382,29 +408,32 @@ def wrapped(*args):
                            for a in flat_args)
     flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype)
                            for a in flat_out_shapes)
+    canonicalized_configs = []
+    if autotuning_configs is None:
+      canonicalized_configs.append(_canonicalize_kernel_config(config,
+                                                               flat_in_avals,
+                                                               flat_out_avals,
+                                                               in_specs,
+                                                               out_specs,
+                                                               grid))
+    else:
+      canonicalized_configs.extend(map(partial(_canonicalize_kernel_config,
+                                               in_avals=flat_in_avals,
+                                               out_avals=flat_out_avals,
+                                               in_specs=in_specs,
+                                               out_specs=out_specs,
+                                               grid=grid),
+                                       autotuning_configs))
     kernels = []
-    flat_in_specs = in_specs
-    flat_out_specs = out_specs
-    if flat_in_specs is None:
-      flat_in_specs = [None] * len(flat_in_avals)
-    if flat_out_specs is None:
-      flat_out_specs = [None] * len(flat_out_avals)
     all_consts = []
-    if autotuning_configs is None:
+    if len(canonicalized_configs) == 0:
+      raise ValueError("Cannot pass in empty autotuning configs")
+    for canonicalized_config in canonicalized_configs:
       specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-          Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals,
+          canonicalized_config, f, name, flat_in_avals,
          flat_out_avals, jaxpr_in_tree, compiler_params)
       kernels.append(specialized_kernel)
       all_consts.extend(consts)
-    else:
-      if len(autotuning_configs) == 0:
-        raise ValueError("Cannot pass in empty autotuning configs")
-      for config in autotuning_configs:
-        specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-            config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals,
-            jaxpr_in_tree, compiler_params)
-        kernels.append(specialized_kernel)
-        all_consts.extend(consts)
     if all_consts:
       raise NotImplementedError("Cannot handle consts.")
     del jaxpr_out_tree
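
Putting the new surface together: pallas_call now accepts exactly one of (a) explicit grid/in_specs/out_specs, (b) a single config=KernelConfig(...), or (c) a list of autotuning_configs, and _canonicalize_kernel_config folds each form into a KernelConfig before specialization. A rough usage sketch in the style of the updated tests below; the add_one kernel body and the 8-element shape are illustrative, not defined in this file:

import functools

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    autotuning_configs=[
        pl.KernelConfig(meta=dict(block_size=block_size),
                        grid=8 // block_size)
        for block_size in [1, 2, 4, 8]
    ])
def add_one(x_ref, o_ref, *, block_size):
  # Each autotuning config re-specializes the kernel with its own meta kwargs.
  idx = pl.program_id(0) * block_size + jnp.arange(block_size)
  o_ref[idx] = x_ref[idx] + 1.

Mixing styles, for example passing config= together with grid= or in_specs=, now raises a ValueError instead of silently preferring one of them.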

jax_triton/pallas/triton_ir_lowering.py

Lines changed: 5 additions & 3 deletions
@@ -837,9 +837,11 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
     if debug:
       print(kernel.jaxpr)
       print(kernel.grid_spec)
-    compiler_params = kernel.compiler_params
-    num_warps = compiler_params.get("num_warps", 4)
-    num_stages = compiler_params.get("num_stages", 3)
+    compiler_params = dict(kernel.compiler_params)
+    num_warps = compiler_params.pop("num_warps", 4)
+    num_stages = compiler_params.pop("num_stages", 3)
+    if compiler_params:
+      raise ValueError(f"Invalid compiler params: {compiler_params}")
     compilation_result = compile_jaxpr(kernel.jaxpr, kernel.num_consts,
                                        tuple((*in_shapes, *out_shapes)),
                                        kernel.grid_spec, kernel.name, num_warps, num_stages)
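
The lowering previously read num_warps and num_stages with .get() and ignored anything else in compiler_params; it now pops them from a copy and rejects whatever remains, so a misspelled option becomes a hard error instead of a silent no-op. These values still arrive through pallas_call's trailing **compiler_params. A sketch under that assumption (add_one_kernel is a placeholder name for any pallas kernel that takes a block_size meta kwarg):

add_one_tuned = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    config=pl.KernelConfig(meta=dict(block_size=4), grid=2),
    num_warps=4,    # recognized: popped by the Triton lowering
    num_stages=3)   # recognized: popped by the Triton lowering

# A typo such as num_stage=3 now fails at lowering time with
#   ValueError: Invalid compiler params: {'num_stage': 3}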

tests/pallas_test.py

Lines changed: 12 additions & 14 deletions
@@ -857,10 +857,10 @@ def test_basic_autotuning(self):
 
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
-        grid=lambda block_size: 8 // block_size,
         autotuning_configs=[
-          pl.Config(dict(block_size=2), {}),
-          pl.Config(dict(block_size=4), {}),
+          pl.KernelConfig(meta=dict(block_size=block_size),
+                          grid=8 // block_size)
+          for block_size in [1, 2, 4, 8]
         ])
     def add_one(x_ref, o_ref, *, block_size):
       idx = pl.program_id(0) * block_size + jnp.arange(block_size)
@@ -873,18 +873,16 @@ def test_basic_autotuning_with_block_spec(self):
 
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
-        grid=lambda block_size: 8 // block_size,
-        in_specs=[
-          lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
-        ],
-        out_specs=[
-          lambda block_size: pl.BlockSpec(lambda i: i, (block_size,)),
-        ],
         autotuning_configs=[
-          pl.Config(dict(block_size=1), {}),
-          pl.Config(dict(block_size=2), {}),
-          pl.Config(dict(block_size=4), {}),
-          pl.Config(dict(block_size=8), {}),
+          pl.KernelConfig(meta=dict(block_size=block_size),
+                          in_specs=[
+                            pl.BlockSpec(lambda i: i, (block_size,))
+                          ],
+                          out_specs=[
+                            pl.BlockSpec(lambda i: i, (block_size,))
+                          ],
+                          grid=8 // block_size)
+          for block_size in [1, 2, 4, 8]
         ],
         debug=True)
     def add_one(x_ref, o_ref, *, block_size):
