A faster flash attention bwd implementation #177
base: main
Conversation
- Decompose the bwd kernel into two kernels, one for dq and one for dk/dv.
- Extra parallelism over the sequence-length axis (see the grid sketch below).
- On a benchmark, it is 4x faster than the previous implementation and 2x faster than the XLA bwd pass.
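For intuition, here is a rough sketch of where the extra parallelism comes from; the shapes and block size below are illustrative, not the PR's actual configuration:

```python
# Illustrative only: how many programs each backward variant can launch,
# using made-up shapes and a made-up block size.
batch, num_heads, seq_len, head_dim = 2, 4, 2048, 64
block = 128

fused_grid = (batch, num_heads)                    # previous bwd: one program per (batch, head)
split_grid = (batch, num_heads, seq_len // block)  # this PR: also parallel over sequence blocks

print(fused_grid)  # (2, 4)     -> 8 programs
print(split_grid)  # (2, 4, 16) -> 128 programs
```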
High-level comment: the current backward pass is a fully fused kernel that parallelizes over batch * num_heads threads.
For attention shapes with a small batch and few heads (as is common in language model training), this kernel will underutilize the GPU.
However, there are applications where the fused kernel is faster than the two-kernel variant.
Could you add the two-kernel version as a separate backward pass impl, so the user can select the one they want?
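A minimal sketch of that kind of selection, using hypothetical names `_mha_backward_fused` and `_mha_backward_split` as stand-ins for the two implementations; the PR itself keys off a static `backward_pass_impl` string, as the diff below shows:

```python
# Sketch only: dispatch between the fused and split backward passes via a
# static string. The placeholder functions stand in for the real kernels.
def _mha_backward_fused(*residuals):
  raise NotImplementedError  # placeholder for the existing fused kernel

def _mha_backward_split(*residuals):
  raise NotImplementedError  # placeholder for the new two-kernel version

_BWD_IMPLS = {
    "triton": _mha_backward_fused,        # one fused kernel over (batch, heads)
    "triton_split": _mha_backward_split,  # dq kernel + dk/dv kernel, extra seq-length axis
}

def select_bwd(backward_pass_impl: str):
  if backward_pass_impl not in _BWD_IMPLS:
    raise ValueError(f"unknown backward_pass_impl: {backward_pass_impl}")
  return _BWD_IMPLS[backward_pass_impl]
```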
Could you also add tests into pallas_test.py?
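A hedged sketch of what such a test could check, assuming an `mha` entry point that accepts a `backward_pass_impl` keyword; the exact signature, tensor layout, and tolerances here are assumptions:

```python
import jax
import jax.numpy as jnp

def check_backward_impls_agree(mha, q, k, v):
  # Compare the gradients produced by the fused and split backward passes.
  dout = jnp.ones(q.shape, q.dtype)  # output has the same shape as q
  _, vjp_fused = jax.vjp(lambda *xs: mha(*xs, backward_pass_impl="triton"), q, k, v)
  _, vjp_split = jax.vjp(lambda *xs: mha(*xs, backward_pass_impl="triton_split"), q, k, v)
  for g_fused, g_split in zip(vjp_fused(dout), vjp_split(dout)):
    assert jnp.allclose(g_fused, g_split, atol=1e-2, rtol=1e-2)
```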
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
Can you rename to be (i, j, _)? Same below?
upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
I don't think we need eviction policy here
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
Nit: indentation
@@ -346,6 +450,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
Comment is not accurate here
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
I suspect we don't need dq to be f32 anymore. Could you try q.dtype?
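The change being suggested, as a small sketch; the shape and dtype below are illustrative, not the PR's actual inputs:

```python
import jax
import jax.numpy as jnp

q = jnp.zeros((2, 2048, 4, 64), dtype=jnp.bfloat16)  # illustrative query tensor

# Allocate dq in the input dtype instead of forcing float32; whether the
# accumulation still has enough precision is what needs to be verified.
out_shapes_q = jax.ShapeDtypeStruct(q.shape, q.dtype)
```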
@sharadmv Can this PR be merged? We see a big performance improvement on NVIDIA A100 GPUs with this PR.
I left some comments. @tonywu95 do you have time to address them?
Hey @tonywu95, is it ok if we take over this PR and put you as a co-author? We'd love to get it in!