|
31 | 31 | from jax._src.util import weakref_lru_cache, safe_map, safe_zip
|
32 | 32 | from jax._src.state.types import AbstractRef
|
33 | 33 |
|
| 34 | +from jax_triton.pallas import tracing_utils |
| 35 | + |
34 | 36 | map, unsafe_map = safe_map, map
|
35 | 37 | zip, unsafe_zip = safe_zip, zip
|
36 | 38 |
|
@@ -91,23 +93,18 @@ class GridSpec:
|
91 | 93 |
|
92 | 94 | replace = dataclasses.replace
|
93 | 95 |
|
94 |
| -@weakref_lru_cache |
95 |
| -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, |
96 |
| - primitive_name: Optional[str] = None): |
97 |
| - wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( |
98 |
| - lu.wrap_init(fun), in_tree) |
99 |
| - debug_info = pe.debug_info(fun, in_tree, out_tree_thunk, False, |
100 |
| - primitive_name or "<unknown>") |
101 |
| - jaxpr, consts = _initial_style_flat_jaxpr(wrapped_fun, in_avals, |
102 |
| - debug_info=debug_info) |
103 |
| - return jaxpr, consts, out_tree_thunk() |
104 |
| - |
105 |
| -def _initial_style_flat_jaxpr(fun: lu.WrappedFun, in_avals, |
106 |
| - debug_info: Optional[jax_core.DebugInfo] = None |
107 |
| - ) -> tuple[jax_core.Jaxpr, list[Any]]: |
108 |
| - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals, debug_info) |
109 |
| - jaxpr = for_loop._hoist_consts_to_refs(jaxpr) |
110 |
| - return jaxpr, consts |
| 96 | +Platform = str |
| 97 | + |
| 98 | +@dataclasses.dataclass |
| 99 | +class Config: |
| 100 | + meta: dict[str, Any] |
| 101 | + compiler_params: dict[Platform, dict[str, Any]] |
| 102 | + |
| 103 | + def to_string(self, platform: str) -> str: |
| 104 | + compiler_params = self.compiler_params.get(platform, {}) |
| 105 | + return "-".join([*(f"{k}_{v}" for k, v in self.meta.items()), |
| 106 | + *(f"{k}_{v}" for k, v in compiler_params.items())]) |
| 107 | + |
111 | 108 |
|
112 | 109 | def preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid:
|
113 | 110 | if grid is None:
|
@@ -141,64 +138,8 @@ def compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
|
141 | 138 |
|
142 | 139 | @dataclasses.dataclass
|
143 | 140 | class SpecializedKernel:
|
| 141 | + name: str |
144 | 142 | jaxpr: jax_core.Jaxpr
|
| 143 | + num_consts: int |
145 | 144 | grid_spec: GridSpec
|
146 |
| - |
147 |
| -@dataclasses.dataclass(frozen=True) |
148 |
| -class Kernel: |
149 |
| - func: lu.WrappedFun |
150 |
| - name: Optional[str] |
151 |
| - grid: Optional[Grid] |
152 |
| - in_specs: Optional[list[Optional[BlockSpec]]] |
153 |
| - out_specs: Optional[list[Optional[BlockSpec]]] |
154 |
| - |
155 |
| - def __post_init__(self): |
156 |
| - if self.grid is None: |
157 |
| - if self.in_specs is not None: |
158 |
| - raise ValueError("Cannot specify `in_specs` with a `None` grid.") |
159 |
| - if self.out_specs is not None: |
160 |
| - raise ValueError("Cannot specify `out_specs` with a `None` grid.") |
161 |
| - |
162 |
| - def get_name(self) -> str: |
163 |
| - return extract_function_name(self.func, self.name) |
164 |
| - |
165 |
| - def specialize(self, |
166 |
| - in_avals: tuple[AbstractRef, ...], |
167 |
| - out_avals: tuple[AbstractRef, ...], |
168 |
| - in_tree: tree_util.PyTreeDef |
169 |
| - ) -> tuple[SpecializedKernel, ...]: |
170 |
| - grid = preprocess_grid(self.grid) |
171 |
| - in_specs = self.in_specs |
172 |
| - out_specs = self.out_specs |
173 |
| - if out_specs is not None and not isinstance(out_specs, (tuple, list)): |
174 |
| - out_specs = (out_specs,) |
175 |
| - if out_specs is not None and not isinstance(out_specs, tuple): |
176 |
| - out_specs = tuple(out_specs) |
177 |
| - if in_specs is None: |
178 |
| - in_specs = [None] * len(in_avals) |
179 |
| - if out_specs is None: |
180 |
| - out_specs = [None] * len(out_avals) |
181 |
| - if grid == (): |
182 |
| - in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) |
183 |
| - for arg in in_avals] |
184 |
| - out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype) |
185 |
| - for arg in out_avals] |
186 |
| - else: |
187 |
| - in_ref_avals = [ |
188 |
| - state.shaped_array_ref( |
189 |
| - compute_shape_from_block_spec(block_spec, aval.shape), |
190 |
| - aval.dtype) |
191 |
| - for block_spec, aval in zip(in_specs, in_avals)] |
192 |
| - out_ref_avals = [ |
193 |
| - state.shaped_array_ref( |
194 |
| - compute_shape_from_block_spec(block_spec, aval.shape), |
195 |
| - aval.dtype) |
196 |
| - for block_spec, aval in zip(out_specs, out_avals)] |
197 |
| - in_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), |
198 |
| - in_specs) |
199 |
| - out_block_mappings = map(partial(convert_block_spec_to_block_mapping, grid), |
200 |
| - out_specs) |
201 |
| - grid_spec = GridSpec(grid, (*in_block_mappings, *out_block_mappings), ()) |
202 |
| - jaxpr, consts, out_tree = _initial_style_open_jaxpr( |
203 |
| - self.func, in_tree, tuple((*in_ref_avals, *out_ref_avals))) |
204 |
| - return [SpecializedKernel(jaxpr, grid_spec)], consts, out_tree |
| 145 | + compiler_params: dict[str, Any] |
0 commit comments