Skip to content

Commit ec7f0b6

Browse files
committed
Working!
1 parent 0cc62dc commit ec7f0b6

File tree

6 files changed

+371
-224
lines changed

6 files changed

+371
-224
lines changed

jax_triton/pallas/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Module for pallas, a jaxpr "dialect" for Triton."""
1616
from jax_triton.pallas.core import BlockSpec
17+
from jax_triton.pallas.core import Config
1718
from jax_triton.pallas.pallas_call import pallas_call
1819
from jax_triton.pallas.pallas_call import pallas_call_p
1920
from jax_triton.pallas.primitives import atomic_add

jax_triton/pallas/core.py

Lines changed: 17 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from jax._src.util import weakref_lru_cache, safe_map, safe_zip
3232
from jax._src.state.types import AbstractRef
3333

34+
from jax_triton.pallas import tracing_utils
35+
3436
map, unsafe_map = safe_map, map
3537
zip, unsafe_zip = safe_zip, zip
3638

@@ -91,23 +93,18 @@ class GridSpec:
9193

9294
replace = dataclasses.replace
9395

94-
@weakref_lru_cache
95-
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
96-
primitive_name: Optional[str] = None):
97-
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
98-
lu.wrap_init(fun), in_tree)
99-
debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False,
100-
primitive_name or "<unknown>")
101-
jaxpr, consts = _initial_style_flat_jaxpr(wrapped_fun, in_avals,
102-
debug_info=debug_info)
103-
return jaxpr, consts, out_tree_thunk()
104-
105-
def _initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals,
106-
debug_info: Optional[jax_core.DebugInfo] = None
107-
) -> tuple[jax_core.Jaxpr, list[Any]]:
108-
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info)
109-
jaxpr = for_loop._hoist_consts_to_refs(jaxpr)
110-
return jaxpr, consts
96+
Platform = str
97+
98+
@dataclasses.dataclass
99+
class Config:
100+
meta: dict[str, Any]
101+
compiler_params: dict[Platform, dict[str, Any]]
102+
103+
def to_string(self, platform: str) -> str:
104+
compiler_params = self.compiler_params.get(platform, {})
105+
return "-".join([*(f"{k}_{v}" for k, v in self.meta.items()),
106+
*(f"{k}_{v}" for k, v in compiler_params.items())])
107+
111108

112109
def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid:
113110
if grid is None:
@@ -141,64 +138,8 @@ def compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
141138

142139
@dataclasses.dataclass
143140
class SpecializedKernel:
141+
name: str
144142
jaxpr: jax_core.Jaxpr
143+
num_consts: int
145144
grid_spec: GridSpec
146-
147-
@dataclasses.dataclass(frozen=True)
148-
class Kernel:
149-
func: lu.WrappedFun
150-
name: Optional[str]
151-
grid: Optional[Grid]
152-
in_specs: Optional[list[Optional[BlockSpec]]]
153-
out_specs: Optional[list[Optional[BlockSpec]]]
154-
155-
def __post_init__(self):
156-
if self.grid is None:
157-
if self.in_specs is not None:
158-
raise ValueError("Cannot specify `in_specs` with a `None` grid.")
159-
if self.out_specs is not None:
160-
raise ValueError("Cannot specify `out_specs` with a `None` grid.")
161-
162-
def get_name(self) -> str:
163-
return extract_function_name(self.func, self.name)
164-
165-
def specialize(self,
166-
in_avals: tuple[AbstractRef, ...],
167-
out_avals: tuple[AbstractRef, ...],
168-
in_tree: tree_util.PyTreeDef
169-
) -> tuple[SpecializedKernel, ...]:
170-
grid = preprocess_grid(self.grid)
171-
in_specs = self.in_specs
172-
out_specs = self.out_specs
173-
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
174-
out_specs = (out_specs,)
175-
if out_specs is not None and not isinstance(out_specs, tuple):
176-
out_specs = tuple(out_specs)
177-
if in_specs is None:
178-
in_specs = [None] * len(in_avals)
179-
if out_specs is None:
180-
out_specs = [None] * len(out_avals)
181-
if grid == ():
182-
in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
183-
for arg in in_avals]
184-
out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
185-
for arg in out_avals]
186-
else:
187-
in_ref_avals = [
188-
state.shaped_array_ref(
189-
compute_shape_from_block_spec(block_spec, aval.shape),
190-
aval.dtype)
191-
for block_spec, aval in zip(in_specs, in_avals)]
192-
out_ref_avals = [
193-
state.shaped_array_ref(
194-
compute_shape_from_block_spec(block_spec, aval.shape),
195-
aval.dtype)
196-
for block_spec, aval in zip(out_specs, out_avals)]
197-
in_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid),
198-
in_specs)
199-
out_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid),
200-
out_specs)
201-
grid_spec = GridSpec(grid, (*in_block_mappings, *out_block_mappings), ())
202-
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
203-
self.func, in_tree, tuple((*in_ref_avals, *out_ref_avals)))
204-
return [SpecializedKernel(jaxpr, grid_spec)], consts, out_tree
145+
compiler_params: dict[str, Any]

0 commit comments

Comments
 (0)