@@ -103,19 +103,22 @@ def body(i, refs):
   acc = acc.astype(o_ref.dtype)
   pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)


-@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10])
+@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11])
 @functools.partial(jax.jit, static_argnames=["sm_scale", "block_q", "block_k",
+                                             "backward_pass_impl",
                                              "num_warps", "num_stages", "grid",
                                              "interpret", "debug"])
 def mha(q, k, v,
         sm_scale: float = 1.0,
         block_q: int = 128,
         block_k: int = 128,
+        backward_pass_impl: str = "triton",
         num_warps: Optional[int] = None,
         num_stages: int = 1,
         grid=None,
         interpret: bool = False,
         debug: bool = False):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -156,8 +159,10 @@ def mha(q, k, v,
   return out

 def _mha_forward(q, k, v, sm_scale: float, block_q: int, block_k: int,
+                 backward_pass_impl: str,
                  num_warps: Optional[int], num_stages: int, grid: Any,
                  interpret: bool, debug: bool):
+  del backward_pass_impl
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
@@ -257,7 +262,7 @@ def mha_backward_kernel(
     *, sm_scale: float,
     block_q: int, block_d: int, block_k: int
 ):
-  del out_ref, l_ref # Not needed
+  del out_ref, l_ref  # Not needed
   seq_len = q_ref.shape[0]

   def outer_loop(start_k, _):
@@ -298,53 +303,60 @@ def inner_loop(start_q, refs):
                       slice(None)), dk.astype(dk_ref.dtype))
   for_loop(jt.cdiv(seq_len, block_k), outer_loop, ())

-def _mha_backward(sm_scale: float, block_q: int, block_k: int, num_warps:
-                  Optional[int], num_stages: int, grid: Any, interpret: bool,
+def _mha_backward(sm_scale: float, block_q: int, block_k: int,
+                  backward_pass_impl: str, num_warps: Optional[int],
+                  num_stages: int, grid: Any, interpret: bool,
                   debug: bool, res, do):
   del num_warps, num_stages, grid
   q, k, v, out, l, m = res
+
   batch_size, seq_len, num_heads, head_dim = q.shape
   block_q = min(block_q, seq_len)
   block_k = min(block_k, seq_len)
   do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
-  # We accumulate into dq so we need to initialize it to zeros.
-  dq = jnp.zeros(q.shape, jnp.float32)

-  out_shapes = [
-    jax.ShapeDtypeStruct(dq.shape, dq.dtype),
-    jax.ShapeDtypeStruct(k.shape, k.dtype),
-    jax.ShapeDtypeStruct(v.shape, v.dtype),
-  ]
+  if backward_pass_impl == "xla":
+    return jax.vjp(mha_reference, q, k, v)[1](do)
+  elif backward_pass_impl == "triton":
+    # We accumulate into dq so we need to initialize it to zeros.
+    dq = jnp.zeros(q.shape, jnp.float32)
+    out_shapes = [
+      jax.ShapeDtypeStruct(dq.shape, dq.dtype),
+      jax.ShapeDtypeStruct(k.shape, k.dtype),
+      jax.ShapeDtypeStruct(v.shape, v.dtype),
+    ]

-  grid = (batch_size, num_heads)
-  num_warps = 8
-  dq, dk, dv = pl.pallas_call(
-      functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
-                        block_k=block_k, sm_scale=sm_scale),
-      grid=grid,
-      out_shape=out_shapes,
-      in_specs=[
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      out_specs=[
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-      ],
-      name="mha_backward",
-      debug=debug,
-      interpret=interpret,
-      num_warps=num_warps,
-      num_stages=1,
-      input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+    grid = (batch_size, num_heads)
+    num_warps = 8
+    dq, dk, dv = pl.pallas_call(
+        functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
+                          block_k=block_k, sm_scale=sm_scale),
+        grid=grid,
+        out_shape=out_shapes,
+        in_specs=[
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        out_specs=[
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+          pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
+        ],
+        name="mha_backward",
+        debug=debug,
+        interpret=interpret,
+        num_warps=num_warps,
+        num_stages=1,
+        input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
+  else:
+    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
   return dq.astype(q.dtype), dk, dv
 mha.defvjp(_mha_forward, _mha_backward)
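For illustration, a minimal usage sketch of the new backward_pass_impl option added by this diff. The import path, shapes, and dtype below are assumptions made for the example, not part of the change itself:

# Hypothetical usage sketch; import path and inputs are assumptions.
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops.attention import mha  # assumed import path

# Toy inputs with layout (batch, seq_len, num_heads, head_dim).
q = k = v = jnp.ones((1, 128, 2, 64), jnp.float16)

# Differentiate through the Pallas/Triton backward kernel (the default)...
dq_triton = jax.grad(lambda q: mha(q, k, v, backward_pass_impl="triton").sum())(q)
# ...or fall back to differentiating the XLA reference implementation.
dq_xla = jax.grad(lambda q: mha(q, k, v, backward_pass_impl="xla").sum())(q)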