Commit e8137b8 (2 parents: e3a1931 + 88b0f4d)
Author: The jax_triton Authors (committed)

Merge pull request #78 from jax-ml:backwards-xla

PiperOrigin-RevId: 508552584

File tree: 2 files changed, +56 -44 lines changed


jax_triton/pallas/ops/attention.py

Lines changed: 52 additions & 40 deletions
@@ -103,19 +103,22 @@ def body(i, refs):
   acc = acc.astype(o_ref.dtype)
   pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)
 
-@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10])
+@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11])
 @functools.partial(jax.jit, static_argnames=["sm_scale", "block_q", "block_k",
+                                             "backward_pass_impl",
                                              "num_warps", "num_stages", "grid",
                                              "interpret", "debug"])
 def mha(q, k, v,
         sm_scale: float = 1.0,
         block_q: int = 128,
         block_k: int = 128,
+        backward_pass_impl: str = "triton",
         num_warps: Optional[int] = None,
         num_stages: int = 1,
         grid=None,
         interpret: bool = False,
         debug: bool = False):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -156,8 +159,10 @@ def mha(q, k, v,
   return out
 
 def _mha_forward(q, k, v, sm_scale: float, block_q: int, block_k: int,
+                 backward_pass_impl: str,
                  num_warps: Optional[int], num_stages: int, grid: Any,
                  interpret: bool, debug: bool):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -257,7 +262,7 @@ def mha_backward_kernel(
     *, sm_scale: float,
     block_q: int, block_d: int, block_k: int
 ):
-  del out_ref, l_ref # Not needed
+  del out_ref, l_ref  # Not needed
   seq_len = q_ref.shape[0]
 
   def outer_loop(start_k, _):
@@ -298,53 +303,60 @@ def inner_loop(start_q, refs):
                  slice(None)), dk.astype(dk_ref.dtype))
   for_loop(jt.cdiv(seq_len, block_k), outer_loop, ())
 
-def _mha_backward(sm_scale: float, block_q: int, block_k: int, num_warps:
-                  Optional[int], num_stages: int, grid: Any, interpret: bool,
+def _mha_backward(sm_scale: float, block_q: int, block_k: int,
+                  backward_pass_impl: str, num_warps: Optional[int],
+                  num_stages: int, grid: Any, interpret: bool,
                   debug: bool, res, do):
   del num_warps, num_stages, grid
   q, k, v, out, l, m = res
+
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
   do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
-  # We accumulate into dq so we need to initialize it to zeros.
-  dq = jnp.zeros(q.shape, jnp.float32)
 
-  out_shapes = [
-      jax.ShapeDtypeStruct(dq.shape, dq.dtype),
-      jax.ShapeDtypeStruct(k.shape, k.dtype),
-      jax.ShapeDtypeStruct(v.shape, v.dtype),
-  ]
+  if backward_pass_impl == "xla":
+    return jax.vjp(mha_reference, q, k, v)[1](do)
+  elif backward_pass_impl == "triton":
+    # We accumulate into dq so we need to initialize it to zeros.
+    dq = jnp.zeros(q.shape, jnp.float32)
+    out_shapes = [
+        jax.ShapeDtypeStruct(dq.shape, dq.dtype),
+        jax.ShapeDtypeStruct(k.shape, k.dtype),
+        jax.ShapeDtypeStruct(v.shape, v.dtype),
+    ]
 
-  grid = (batch_size, num_heads)
-  num_warps = 8
-  dq, dk, dv = pl.pallas_call(
-      functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
-                        block_k=block_k, sm_scale=sm_scale),
-      grid=grid,
-      out_shape=out_shapes,
-      in_specs=[
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      out_specs=[
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      name="mha_backward",
-      debug=debug,
-      interpret=interpret,
-      num_warps=num_warps,
-      num_stages=1,
-      input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+    grid = (batch_size, num_heads)
+    num_warps = 8
+    dq, dk, dv = pl.pallas_call(
+        functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
+                          block_k=block_k, sm_scale=sm_scale),
+        grid=grid,
+        out_shape=out_shapes,
+        in_specs=[
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        out_specs=[
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        name="mha_backward",
+        debug=debug,
+        interpret=interpret,
+        num_warps=num_warps,
+        num_stages=1,
+        input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+  else:
+    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
   return dq.astype(q.dtype), dk, dv
 mha.defvjp(_mha_forward, _mha_backward)
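
Note (not from this commit): the new backward_pass_impl argument chooses between the Pallas/Triton backward kernel and the XLA reference VJP, which makes it straightforward to cross-check gradients. The sketch below shows one way to do that. It assumes the module is importable as jax_triton.pallas.ops.attention, that a GPU backend is available, and that keyword arguments are accepted here (otherwise pass them positionally); the shapes and dtypes are illustrative only.

# Sketch: compare gradients from the Triton backward kernel against the
# XLA reference VJP (assumed import path and illustrative shapes).
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention

batch, seq_len, num_heads, head_dim = 2, 1024, 4, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (batch, seq_len, num_heads, head_dim)
q = jax.random.normal(k1, shape, dtype=jnp.float16)
k = jax.random.normal(k2, shape, dtype=jnp.float16)
v = jax.random.normal(k3, shape, dtype=jnp.float16)

def make_loss(impl):
  # backward_pass_impl is a static / nondiff argument of attention.mha.
  return lambda q, k, v: attention.mha(q, k, v, backward_pass_impl=impl).sum()

grads_triton = jax.grad(make_loss("triton"), argnums=(0, 1, 2))(q, k, v)
grads_xla = jax.grad(make_loss("xla"), argnums=(0, 1, 2))(q, k, v)
for g_t, g_x in zip(grads_triton, grads_xla):
  print(jnp.max(jnp.abs(g_t - g_x)))  # expect only small numerical differences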

jax_triton/pallas/primitives.py

Lines changed: 4 additions & 4 deletions
@@ -326,9 +326,9 @@ def _load_pp_rule(eqn, context, settings):
   idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
   idx = _pp_idx(eqn.invars[0].aval, idx, context)
   lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
-  return [lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
+  return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
       pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
-  ]))]
+  ]))])
 jax_core.pp_eqn_rules[load_p] = _load_pp_rule
 
 def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
@@ -400,9 +400,9 @@ def _swap_pp_rule(eqn, context, settings):
   idx = _pp_idx(eqn.invars[0].aval, idx, context)
   lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
   if isinstance(y, jax_core.DropVar):
-    return [state_primitives.pp_ref(pp.concat([
+    return pp.concat([state_primitives.pp_ref(pp.concat([
         pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']'),
-        pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))]
+        pp.text(" <- "), pp.text(jax_core.pp_var(val, context))]))])
 jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
 
 def _swap_jvp(primals, tangents, *, args_tree, masked, **params: Any):
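
Note (not from this commit): both pretty-printing rules now return a single pp document built with pp.concat instead of a Python list of documents, which is the form jax_core's equation printer composes further. The snippet below is a minimal sketch of that combinator pattern; it relies on JAX's internal pretty_printer module, whose import path and API may shift between versions.

# Sketch: build one Doc with pp.concat / pp.text, mirroring the shape of the
# corrected _load_pp_rule return value (internal JAX API; may change).
from jax._src import pretty_printer as pp

lhs = pp.text("y")
doc = pp.concat([lhs, pp.text(" <- "), pp.text("x_ref"),
                 pp.text("["), pp.text("i, :"), pp.text("]")])
print(doc.format())  # prints: y <- x_ref[i, :]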
